Skip to content

Commit

Permalink
try catch for zenodo api and move tqdm import
Browse files Browse the repository at this point in the history
  • Loading branch information
rkansal47 committed Oct 26, 2023
1 parent 0c9e673 commit 3f09286
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
14 changes: 11 additions & 3 deletions jetnet/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,17 @@ def _getZenodoFileURL(record_id: int, file_name: str) -> str:

records_url = f"https://zenodo.org/api/records/{record_id}"
r = requests.get(records_url).json()
file = next(item for item in r["files"] if item["filename"] == file_name)
file_url = file["links"]["download"]
md5 = file["checksum"]

# Zenodo API seems to be switching back and forth between these at the moment... so trying both
try:
file = next(item for item in r["files"] if item["filename"] == file_name)
file_url = file["links"]["download"]
md5 = file["checksum"]
except KeyError:
file = next(item for item in r["files"] if item["key"] == file_name)
file_url = file["links"]["self"]
md5 = file["checksum"].split("md5:")[1]

return file_url, md5


Expand Down
3 changes: 2 additions & 1 deletion jetnet/evaluation/gen_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from scipy.stats import iqr, wasserstein_distance
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from jetnet import utils
from jetnet.datasets import JetNet
Expand All @@ -40,6 +39,8 @@ def _check_get_ndarray(*arrs):

def _optional_tqdm(iter_obj, use_tqdm, total=None, desc=None):
if use_tqdm:
from tqdm import tqdm

return tqdm(iter_obj, total=total, desc=desc)
else:
return iter_obj
Expand Down

0 comments on commit 3f09286

Please sign in to comment.