Merge branch 'main' into JOSS-2023
jmduarte authored Oct 25, 2023
2 parents 99baaab + 436091b commit 0694079
Showing 11 changed files with 117 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ docs/_build
/dist
*egg-info
*test*.py
*test*.ipynb
/datasets
.vscode

4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
@@ -27,7 +27,7 @@ repos:
args: ["--profile", "black"]

- repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.7.0
+    rev: 23.10.0
hooks:
- id: black-jupyter
language_version: python3
9 changes: 7 additions & 2 deletions .readthedocs.yaml
@@ -5,6 +5,12 @@
# Required
version: 2

# Set the OS, Python version and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.11"

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
@@ -13,6 +19,5 @@ sphinx:
formats: all

python:
-  version: 3.8
install:
    - requirements: docs/requirements.txt
30 changes: 25 additions & 5 deletions README.md
@@ -13,6 +13,7 @@ ______________________________________________________________________
<a href="#installation">Installation</a> •
<a href="#quickstart">Quickstart</a> •
<a href="#documentation">Documentation</a> •
<a href="#contributing">Contributing</a> •
<a href="#citation">Citation</a> •
<a href="#references">References</a>
</p>
@@ -21,7 +22,7 @@ ______________________________________________________________________



-![CI](https://github.com/jet-net/jetnet/actions/workflows/ci.yml/badge.svg)
+[![CI](https://github.com/jet-net/jetnet/actions/workflows/ci.yml/badge.svg)](https://github.com/jet-net/jetnet/actions)
[![Documentation Status](https://readthedocs.org/projects/jetnet/badge/?version=latest)](https://jetnet.readthedocs.io/en/latest/)
[![Codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/jet-net/JetNet/main.svg)](https://results.pre-commit.ci/latest/github/jet-net/JetNet/main)
@@ -73,7 +74,6 @@ pip install "jetnet[emdloss]"

Finally, [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) must be installed independently for the Fréchet ParticleNet Distance metric `jetnet.evaluation.fpnd` ([Installation instructions](https://github.com/pyg-team/pytorch_geometric#installation)).


## Quickstart

Datasets can be downloaded and accessed quickly, for example:
@@ -83,7 +83,7 @@ from jetnet.datasets import JetNet, TopTagging
# as numpy arrays:
particle_data, jet_data = JetNet.getData(jet_type=["g", "q"], data_dir="./datasets/jetnet/")
# or as a PyTorch dataset:
-dataset = TopTagging(jet_type="all", , data_dir="./datasets/toptagging/", split="train")
+dataset = TopTagging(jet_type="all", data_dir="./datasets/toptagging/", split="train")
```

Evaluation metrics can be used as such:
@@ -103,9 +103,29 @@ loss.backward()

## Documentation

-The full API reference and tutorials are available at [jetnet.readthedocs.io](https://jetnet.readthedocs.io/en/latest/). Tutorial notebooks are in the [tutorials](tutorials) folder, with more to come.
+The full API reference and tutorials are available at [jetnet.readthedocs.io](https://jetnet.readthedocs.io/en/latest/). Tutorial notebooks are in the [tutorials](https://github.com/jet-net/JetNet/tree/main/tutorials) folder, with more to come.

<!-- More detailed information about each dataset can (or will) be found at [jet-net.github.io](https://jet-net.github.io/). -->

## Contributing

We welcome feedback and contributions! Please feel free to [create an issue](https://github.com/jet-net/JetNet/issues/new) for bugs or functionality requests, or open [pull requests](https://github.com/jet-net/JetNet/pulls) from your [forked repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo) to solve them.

### Building and testing locally

Perform an editable installation of the package from inside your forked repo and install the `pytest` package for unit testing:

```bash
pip install -e .
pip install pytest
```

Run the test suite to ensure everything is working as expected:

```bash
pytest tests # tests all datasets
pytest tests -m "not slow" # tests only on the JetNet dataset for convenience
```
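The `slow` marker filtered above is typically declared in the project's pytest configuration so that `-m "not slow"` can deselect the full-dataset tests; a hypothetical sketch (the file name and marker description here are illustrative assumptions, not taken from this commit):

```ini
# pytest.ini (illustrative only)
[pytest]
markers =
    slow: tests that download and check the full datasets
```

Registering markers this way also silences pytest's unknown-marker warnings when the suite runs.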

## Citation

1 change: 1 addition & 0 deletions docs/conf.py
@@ -50,6 +50,7 @@
"autodocsumm",
"m2r2",
"nbsphinx",
"sphinx_rtd_theme",
]
autosummary_generate = True # Turn on sphinx.ext.autosummary

4 changes: 2 additions & 2 deletions docs/requirements.txt
@@ -7,7 +7,7 @@ nbsphinx
numpy
readthedocs-sphinx-search
scipy
-sphinx
-sphinx_rtd_theme
+sphinx<7
+sphinx_rtd_theme==0.5.2
torch
tqdm
68 changes: 55 additions & 13 deletions jetnet/datasets/utils.py
@@ -3,6 +3,7 @@
"""
from __future__ import annotations

import hashlib
import os
import sys
from os.path import exists
@@ -48,28 +49,69 @@ def download_progress_bar(file_url: str, file_dest: str):
sys.stdout.write("\n")


-def checkDownloadZenodoDataset(data_dir: str, dataset_name: str, record_id: int, key: str):
-    """Checks if dataset exists, if not downloads it from Zenodo, and returns the file path"""
-    file_path = f"{data_dir}/{key}"
-    if not exists(file_path):
-        os.system(f"mkdir -p {data_dir}")
-        file_url = getZenodoFileURL(record_id, key)
# from TorchVision
# https://github.com/pytorch/vision/blob/48f8473e21b0f3e425aabc60db201b68fedf59b3/torchvision/datasets/utils.py#L51-L66 # noqa: E501
def _calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but
# indicates that we are not using the MD5 checksum for cryptography. This enables its usage
# in restricted environments like FIPS.
if sys.version_info >= (3, 9):
md5 = hashlib.md5(usedforsecurity=False)
else:
md5 = hashlib.md5()
with open(fpath, "rb") as f:
# switch to simpler assignment operator once we support only Python >=3.8
# while chunk := f.read(chunk_size):
for chunk in iter(lambda: f.read(chunk_size), b""):
md5.update(chunk)
return md5.hexdigest()

-    print(f"Downloading {dataset_name} dataset to {file_path}")
-    download_progress_bar(file_url, file_path)

-    return file_path
def _check_md5(fpath: str, md5: str, **kwargs: Any) -> tuple[bool, str]:
fmd5 = _calculate_md5(fpath, **kwargs)
return (md5 == fmd5), fmd5


-def getZenodoFileURL(record_id: int, file_name: str) -> str:
-    """Finds URL for downloading the file ``file_name`` from a Zenodo record."""
+def _getZenodoFileURL(record_id: int, file_name: str) -> str:
+    """Finds URL and md5 hash for downloading the file ``file_name`` from a Zenodo record."""

import requests

records_url = f"https://zenodo.org/api/records/{record_id}"
r = requests.get(records_url).json()
-    file_url = next(item for item in r["files"] if item["key"] == file_name)["links"]["self"]
-    return file_url
+    file = next(item for item in r["files"] if item["filename"] == file_name)
+    file_url = file["links"]["download"]
+    md5 = file["checksum"]
+    return file_url, md5


def checkDownloadZenodoDataset(data_dir: str, dataset_name: str, record_id: int, key: str) -> str:
"""
Checks if dataset exists and md5 hash matches;
if not, downloads it from Zenodo, and returns the file path.
"""
file_path = f"{data_dir}/{key}"
file_url, md5 = _getZenodoFileURL(record_id, key)

if exists(file_path):
match_md5, fmd5 = _check_md5(file_path, md5)
if not match_md5:
print(
f"MD5 hash of {file_path} does not match "
f"(expected md5:{md5}, got md5:{fmd5}), "
"removing existing file and re-downloading. "
"Please open an issue at https://github.com/jet-net/JetNet/issues/new "
"if you believe the matching is failing incorrectly."
)
os.remove(file_path)

if not exists(file_path):
os.makedirs(data_dir, exist_ok=True)

print(f"Downloading {dataset_name} dataset to {file_path}")
download_progress_bar(file_url, file_path)

return file_path


def getOrderedFeatures(
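The chunked-MD5 helpers added above can be exercised in isolation; a minimal standard-library sketch of the same streaming pattern (the file suffix and sample bytes are illustrative):

```python
import hashlib
import tempfile


def file_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    """Stream a file through MD5 in fixed-size chunks so memory use stays bounded."""
    md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        # read until f.read() returns b"", hashing one chunk at a time
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


# write a small payload, then compare against a digest computed in one shot
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as f:
    f.write(b"not a real dataset")
    path = f.name

assert file_md5(path) == hashlib.md5(b"not a real dataset").hexdigest()
```

Reading in chunks matters here because the datasets are large HDF5 files; hashing them with a single `f.read()` would load the entire file into memory.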
8 changes: 5 additions & 3 deletions jetnet/evaluation/gen_metrics.py
@@ -21,6 +21,8 @@

rng = np.random.default_rng()

logger = logging.getLogger("jetnet")
logger.setLevel(logging.INFO)

# TODO: generic w1 method

@@ -80,7 +82,7 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
msg = (
"fid calculation produces singular product; " "adding %s to diagonal of cov estimates"
) % eps
logging.debug(msg)
logger.debug(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

@@ -152,7 +154,7 @@ def _get_fpnd_real_mu_sigma(
# run inference and store activations
jets_loaded = DataLoader(jets, batch_size)

logging.info(f"Calculating ParticleNet activations on real jets with batch size {batch_size}")
logger.info(f"Calculating ParticleNet activations on real jets with batch size {batch_size}")
activations = []
for i, jets_batch in _optional_tqdm(
enumerate(jets_loaded), use_tqdm, total=len(jets_loaded), desc="Running ParticleNet"
@@ -294,7 +296,7 @@ def fpnd(
# run inference and store activations
jets_loaded = DataLoader(jets[: _eval_module.fpnd_dict["NUM_SAMPLES"]], batch_size)

logging.info(f"Calculating ParticleNet activations with batch size: {batch_size}")
logger.info(f"Calculating ParticleNet activations with batch size: {batch_size}")
activations = []
for i, jets_batch in _optional_tqdm(
enumerate(jets_loaded), use_tqdm, total=len(jets_loaded), desc="Running ParticleNet"
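The logging change above swaps module-level `logging.info`/`logging.debug` calls for a named `jetnet` logger; a minimal sketch of the pattern (the format string and message are illustrative):

```python
import logging

# a library-scoped logger: messages carry the "jetnet" name, so downstream
# users can raise or lower this library's verbosity without touching the root logger
logger = logging.getLogger("jetnet")
logger.setLevel(logging.INFO)

logging.basicConfig(format="%(name)s:%(levelname)s:%(message)s")
logger.info("Calculating ParticleNet activations with batch size: %d", 256)
logger.debug("this message is suppressed at INFO level")
```

Calling the root `logging.info` directly, as before, tags messages with the root logger and makes them impossible to filter per-library.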
8 changes: 7 additions & 1 deletion tests/datasets/test_jetnet.py
@@ -10,7 +10,7 @@

data_dir = "./datasets/jetnet"
DataClass = JetNet
-jet_types = ["g", "q"]  # faster testing than using full dataset
+jet_types = ["g", "q"]  # subset of jet types
gq_length = 177252 + 170679


@@ -24,6 +24,12 @@
)
@pytest.mark.parametrize("num_particles", [30, 75])
def test_getData(jet_types, num_particles, expected_length, class_id):
# test md5 checksum is working for one of the datasets
if jet_types == "q":
# write random data to file
with open(f"{data_dir}/q{'150' if num_particles > 30 else ''}.hdf5", "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, num_particles=num_particles)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 5)
6 changes: 6 additions & 0 deletions tests/datasets/test_qgjets.py
@@ -34,6 +34,12 @@
)
@pytest.mark.parametrize("file_list", [test_file_list_withbc, test_file_list_withoutbc])
def test_getData(jet_types, split, expected_length, class_id, file_list):
# test md5 checksum is working for one of the datasets
if jet_types == "q" and file_list == test_file_list_withoutbc:
# write random data to file
with open(data_dir + "/" + file_list[-1], "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, file_list=file_list, split=split)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 1)
6 changes: 6 additions & 0 deletions tests/datasets/test_toptagging.py
@@ -32,6 +32,12 @@
],
)
def test_getData(jet_types, split, expected_length, class_id):
# test md5 checksum is working for one of the datasets
if jet_types == "top" and split == "valid":
# write random data to file
with open(f"{data_dir}/val.h5", "wb") as f:
f.write(np.random.bytes(100))

pf, jf = DataClass.getData(jet_types, data_dir, split=split)
assert pf.shape == (expected_length, num_particles, 4)
assert jf.shape == (expected_length, 5)
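The test additions above all follow one pattern: overwrite a cached file with random bytes so the checksum check must catch the corruption and trigger a re-download. The detection step in isolation, using only the standard library (file names and payloads are illustrative):

```python
import hashlib
import os
import tempfile

# pretend this digest came from the Zenodo record metadata
good_bytes = b"original dataset contents"
expected_md5 = hashlib.md5(good_bytes).hexdigest()

# simulate a corrupted cached file, as the tests do with np.random.bytes(100)
with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as f:
    f.write(os.urandom(100))
    cached = f.name

with open(cached, "rb") as f:
    actual_md5 = hashlib.md5(f.read()).hexdigest()

# mismatch detected -> the dataset utility would remove the file and re-download it
assert actual_md5 != expected_md5
```

This is why each test writes garbage to exactly one cached file: the subsequent `getData` call must notice the mismatch and still return correctly shaped arrays.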
