diff --git a/ml4gw/transforms/__init__.py b/ml4gw/transforms/__init__.py index e70ba48..1b94a12 100644 --- a/ml4gw/transforms/__init__.py +++ b/ml4gw/transforms/__init__.py @@ -2,5 +2,6 @@ from .scaler import ChannelWiseScaler from .snr_rescaler import SnrRescaler from .spectral import SpectralDensity +from .spectrogram import MultiResolutionSpectrogram from .waveforms import WaveformProjector, WaveformSampler from .whitening import FixedWhiten, Whiten diff --git a/ml4gw/transforms/spectrogram.py b/ml4gw/transforms/spectrogram.py new file mode 100644 index 0000000..aa8a116 --- /dev/null +++ b/ml4gw/transforms/spectrogram.py @@ -0,0 +1,162 @@ +import warnings +from typing import Dict, List + +import torch +import torch.nn.functional as F +from torchaudio.transforms import Spectrogram + + +class MultiResolutionSpectrogram(torch.nn.Module): + """ + Create a batch of multi-resolution spectrograms + from a batch of timeseries. Input is expected to + have the shape `(B, C, T)`, where `B` is the number + of batches, `C` is the number of channels, and `T` + is the number of time samples. + + For each timeseries, calculate multiple normalized + spectrograms based on the `Spectrogram` `kwargs` given. + Combine the spectrograms by taking the maximum value + from the nearest time-frequncy bin. + + If the largest number of time bins among the spectrograms + is `N` and the largest number of frequency bins is `M`, + the output will have dimensions `(B, C, M, N)` + + Args: + kernel_length: + The length in seconds of the time dimension + of the tensor that will be turned into a + spectrogram + sample_rate: + The sample rate of the timeseries in Hz + kwargs: + Arguments passed in kwargs will used to create + `torchaudio.transforms.Spectrogram`s. Each + argument should be a list of values. Any list + of length greater than 1 should be the same + length + """ + + def __init__( + self, kernel_length: float, sample_rate: float, **kwargs + ) -> None: + super().__init__() + self.kernel_size = kernel_length * sample_rate + # This method of combination makes sense only when + # the spectrograms are normalized, so enforce this + if "normalized" in kwargs.keys(): + if not all(kwargs["normalized"]): + raise ValueError( + "Received a value of False for 'normalized'. " + "This method of combination is sensible only for " + "normalized spectrograms." + ) + else: + kwargs["normalized"] = [True] + self.kwargs = self._check_and_format_kwargs(kwargs) + + self.transforms = torch.nn.ModuleList( + [Spectrogram(**k) for k in self.kwargs] + ) + + dummy_input = torch.ones(int(kernel_length * sample_rate)) + self.shapes = torch.tensor( + [t(dummy_input).shape for t in self.transforms] + ) + + self.num_freqs = max([shape[0] for shape in self.shapes]) + self.num_times = max([shape[1] for shape in self.shapes]) + + left_pad = torch.zeros(len(self.transforms), dtype=torch.int) + top_pad = torch.zeros(len(self.transforms), dtype=torch.int) + bottom_pad = torch.tensor( + [int(self.num_freqs - shape[0]) for shape in self.shapes] + ) + right_pad = torch.tensor( + [int(self.num_times - shape[1]) for shape in self.shapes] + ) + self.register_buffer("left_pad", left_pad) + self.register_buffer("top_pad", top_pad) + self.register_buffer("bottom_pad", bottom_pad) + self.register_buffer("right_pad", right_pad) + + freq_idxs = torch.tensor( + [ + [int(i * shape[0] / self.num_freqs) for shape in self.shapes] + for i in range(self.num_freqs) + ] + ) + freq_idxs = freq_idxs.repeat(self.num_times, 1, 1).transpose(0, 1) + time_idxs = torch.tensor( + [ + [int(i * shape[1] / self.num_times) for shape in self.shapes] + for i in range(self.num_times) + ] + ) + time_idxs = time_idxs.repeat(self.num_freqs, 1, 1) + + self.register_buffer("freq_idxs", freq_idxs) + self.register_buffer("time_idxs", time_idxs) + + def _check_and_format_kwargs(self, kwargs: Dict[str, List]) -> List: + lengths = sorted(set([len(v) for v in kwargs.values()])) + + if lengths[-1] > 3: + warnings.warn( + "Combining too many spectrograms can impede computation time. " + "If performance is slower than desired, try reducing the " + "number of spectrograms", + RuntimeWarning, + ) + + if len(lengths) > 2 or (len(lengths) == 2 and lengths[0] != 1): + raise ValueError( + "Spectrogram keyword args should all have the same " + f"length or be of length one. Got lengths {lengths}" + ) + + if len(lengths) == 2: + size = lengths[1] + kwargs = {k: v * int(size / len(v)) for k, v in kwargs.items()} + + return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())] + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Calculate spectrograms of the input tensor and + combine them into a single spectrogram + + Args: + X: + Batch of multichannel timeseries which will + be used to calculate the multi-resolution + spectrogram. Should have the shape + `(B, C, T)`, where `B` is the number of + batches, `C` is the number of channels, + and `T` is the number of time samples. + """ + if X.shape[-1] != self.kernel_size: + raise ValueError( + "Expected time dimension to be " + f"{self.kernel_size} samples long, got input with " + f"{X.shape[-1]} samples" + ) + + spectrograms = [t(X) for t in self.transforms] + + padded_specs = [] + for spec, left, right, top, bottom in zip( + spectrograms, + self.left_pad, + self.right_pad, + self.top_pad, + self.bottom_pad, + ): + padded_specs.append(F.pad(spec, (left, right, top, bottom))) + + padded_specs = torch.stack(padded_specs) + remapped_specs = padded_specs[..., self.freq_idxs, self.time_idxs] + remapped_specs = torch.diagonal(remapped_specs, dim1=0, dim2=-1) + + return torch.max(remapped_specs, axis=-1)[0] diff --git a/poetry.lock b/poetry.lock index 47eb26f..8d566a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -398,14 +398,15 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "dill" -version = "0.3.7" +version = "0.3.8" description = "serialize all of Python" category = "dev" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" [package.extras] graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] [[package]] name = "distlib" @@ -475,7 +476,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.13.1" description = "A platform independent file lock." -category = "dev" +category = "main" optional = false python-versions = ">=3.8" @@ -514,6 +515,38 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" +[[package]] +name = "fsspec" +version = "2023.12.2" +description = "File-system specification" +category = "main" +optional = false +python-versions = ">=3.8" + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +devel = ["pytest", "pytest-cov"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +tqdm = ["tqdm"] + [[package]] name = "h5py" version = "3.10.0" @@ -698,7 +731,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.3" description = "A very fast and expressive template engine." -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -1054,7 +1087,7 @@ six = "*" name = "markupsafe" version = "2.1.4" description = "Safely add untrusted strings to HTML/XML markup." -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -1097,6 +1130,20 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + [[package]] name = "nbclient" version = "0.9.0" @@ -1177,6 +1224,21 @@ category = "dev" optional = false python-versions = ">=3.5" +[[package]] +name = "networkx" +version = "3.1" +description = "Python package for creating and manipulating graphs and networks" +category = "main" +optional = false +python-versions = ">=3.8" + +[package.extras] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "nodeenv" version = "1.8.0" @@ -1231,56 +1293,115 @@ optional = false python-versions = ">=3.8" [[package]] -name = "nvidia-cublas-cu11" -version = "11.10.3.66" +name = "nvidia-cublas-cu12" +version = "12.1.3.1" description = "CUBLAS native runtime libraries" category = "main" optional = false python-versions = ">=3" -[package.dependencies] -setuptools = "*" -wheel = "*" +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." +category = "main" +optional = false +python-versions = ">=3" [[package]] -name = "nvidia-cuda-nvrtc-cu11" -version = "11.7.99" +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" description = "NVRTC native runtime libraries" category = "main" optional = false python-versions = ">=3" +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +category = "main" +optional = false +python-versions = ">=3" + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +category = "main" +optional = false +python-versions = ">=3" + [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-cublas-cu12 = "*" [[package]] -name = "nvidia-cuda-runtime-cu11" -version = "11.7.99" -description = "CUDA Runtime native Libraries" +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +category = "main" +optional = false +python-versions = ">=3" + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +category = "main" +optional = false +python-versions = ">=3" + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" category = "main" optional = false python-versions = ">=3" [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" [[package]] -name = "nvidia-cudnn-cu11" -version = "8.5.0.96" -description = "cuDNN runtime libraries" +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" category = "main" optional = false python-versions = ">=3" [package.dependencies] -setuptools = "*" -wheel = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.18.1" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +category = "main" +optional = false +python-versions = ">=3" + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.3.101" +description = "Nvidia JIT LTO Library" +category = "main" +optional = false +python-versions = ">=3" + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +category = "main" +optional = false +python-versions = ">=3" [[package]] name = "overrides" -version = "7.6.0" +version = "7.7.0" description = "A decorator to automatically detect mismatch when overriding a method." category = "dev" optional = false @@ -1598,7 +1719,7 @@ python-versions = ">=3.6" [[package]] name = "pytz" -version = "2023.3.post1" +version = "2023.4" description = "World timezone definitions, modern and historical" category = "dev" optional = false @@ -1677,7 +1798,7 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "referencing" -version = "0.32.1" +version = "0.33.0" description = "JSON Referencing + Python" category = "dev" optional = false @@ -1765,7 +1886,7 @@ win32 = ["pywin32"] name = "setuptools" version = "69.0.3" description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "main" +category = "dev" optional = false python-versions = ">=3.8" @@ -1814,6 +1935,17 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "sympy" +version = "1.12" +description = "Computer algebra system (CAS) in Python" +category = "main" +optional = false +python-versions = ">=3.8" + +[package.dependencies] +mpmath = ">=0.19" + [[package]] name = "terminado" version = "0.18.0" @@ -1857,22 +1989,47 @@ python-versions = ">=3.7" [[package]] name = "torch" -version = "1.13.1" +version = "2.1.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" category = "main" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" [package.dependencies] -nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""} -nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""} -nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""} -nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""} +filelock = "*" +fsspec = "*" +jinja2 = "*" +networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +sympy = "*" +triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} typing-extensions = "*" [package.extras] +dynamo = ["jinja2"] opt-einsum = ["opt-einsum (>=3.3)"] +[[package]] +name = "torchaudio" +version = "2.1.2" +description = "An audio package for PyTorch" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +torch = "2.1.2" + [[package]] name = "torchtyping" version = "0.1.4" @@ -1922,6 +2079,22 @@ python-versions = ">=3.8" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "triton" +version = "2.1.0" +description = "A language and compiler for custom Deep Learning operations" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.18)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + [[package]] name = "typeguard" version = "4.1.5" @@ -2044,17 +2217,6 @@ docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] optional = ["python-socks", "wsaccel"] test = ["websockets"] -[[package]] -name = "wheel" -version = "0.42.0" -description = "A built-package format for Python" -category = "main" -optional = false -python-versions = ">=3.7" - -[package.extras] -test = ["pytest (>=6.0.0)", "setuptools (>=65)"] - [[package]] name = "widgetsnbextension" version = "4.0.9" @@ -2078,7 +2240,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "783d8b4dab815aafca222d3f1ce4383fa13d5860227ce66c4732dbcf3287ecf1" +content-hash = "7872f807d45c9daa8449e007cbb62c77765838fc7e9660e12d622a8a3ec35a72" [metadata.files] anyio = [ @@ -2369,6 +2531,7 @@ contourpy = [ {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37"}, {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa"}, {file = "contourpy-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa"}, + {file = "contourpy-1.1.0-cp310-cp310-win32.whl", hash = "sha256:9b2dd2ca3ac561aceef4c7c13ba654aaa404cf885b187427760d7f7d4c57cff8"}, {file = "contourpy-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545"}, @@ -2377,6 +2540,7 @@ contourpy = [ {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e"}, {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a"}, {file = "contourpy-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e"}, + {file = "contourpy-1.1.0-cp311-cp311-win32.whl", hash = "sha256:edb989d31065b1acef3828a3688f88b2abb799a7db891c9e282df5ec7e46221b"}, {file = "contourpy-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae"}, @@ -2385,6 +2549,7 @@ contourpy = [ {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94"}, {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f"}, {file = "contourpy-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1"}, + {file = "contourpy-1.1.0-cp38-cp38-win32.whl", hash = "sha256:108dfb5b3e731046a96c60bdc46a1a0ebee0760418951abecbe0fc07b5b93b27"}, {file = "contourpy-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a"}, @@ -2393,6 +2558,7 @@ contourpy = [ {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f"}, {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a"}, {file = "contourpy-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002"}, + {file = "contourpy-1.1.0-cp39-cp39-win32.whl", hash = "sha256:71551f9520f008b2950bef5f16b0e3587506ef4f23c734b71ffb7b89f8721999"}, {file = "contourpy-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439"}, @@ -2525,8 +2691,8 @@ defusedxml = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] dill = [ - {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, - {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, ] distlib = [ {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, @@ -2604,6 +2770,10 @@ fqdn = [ {file = "fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014"}, {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +fsspec = [ + {file = "fsspec-2023.12.2-py3-none-any.whl", hash = "sha256:d800d87f72189a745fa3d6b033b9dc4a34ad069f60ca60b943a63599f5501960"}, + {file = "fsspec-2023.12.2.tar.gz", hash = "sha256:8548d39e8810b59c38014934f6b31e57f40c1b20f911f4cc2b85389c7e9bf0cb"}, +] h5py = [ {file = "h5py-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f"}, {file = "h5py-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c"}, @@ -2868,6 +3038,7 @@ ligo-segments = [ {file = "ligo-segments-1.4.0.tar.gz", hash = "sha256:e072a844713c5b02efdcaf5bfe4c3a8cd9ef225b08cfd3202a4e185e0f71f5dc"}, {file = "ligo_segments-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:f8a0fb3ec6bef7effb80fb2f4f4ffaac02afa77db41a0c42fce5e969201e621b"}, {file = "ligo_segments-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:da82ae2b839bad0027e0a492b3673344e62360747d41194a265563e6c2903c7d"}, + {file = "ligo_segments-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:601be6d92e52bdebbb5a82b608ed1ceb781a0ea86b3e5333d61af77575b32664"}, {file = "ligo_segments-1.4.0-cp36-cp36m-win_amd64.whl", hash = "sha256:cde0aa9a6cf9d1b66b00d64623a026ed9d37dbc028a9769492c635cfa726fdb9"}, {file = "ligo_segments-1.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:5ce03d6c0d07a14cde9e3a47ecc44bd0efa07a7d58e75a8899be8d35a437b254"}, {file = "ligo_segments-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:1ec7b6dae3e7dbe2e96dc7dba39a843aab9a4df202ad6b80ce708d7439b7480e"}, @@ -2995,6 +3166,10 @@ mistune = [ {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, ] +mpmath = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] nbclient = [ {file = "nbclient-0.9.0-py3-none-any.whl", hash = "sha256:a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15"}, {file = "nbclient-0.9.0.tar.gz", hash = "sha256:4b28c207877cf33ef3a9838cdc7a54c5ceff981194a82eac59d558f05487295e"}, @@ -3011,6 +3186,10 @@ nest-asyncio = [ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] +networkx = [ + {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, + {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, +] nodeenv = [ {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, @@ -3053,26 +3232,55 @@ numpy = [ {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] -nvidia-cublas-cu11 = [ - {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl", hash = "sha256:d32e4d75f94ddfb93ea0a5dda08389bcc65d8916a25cb9f37ac89edaeed3bded"}, - {file = "nvidia_cublas_cu11-11.10.3.66-py3-none-win_amd64.whl", hash = "sha256:8ac17ba6ade3ed56ab898a036f9ae0756f1e81052a317bf98f8c6d18dc3ae49e"}, +nvidia-cublas-cu12 = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] +nvidia-cuda-cupti-cu12 = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] +nvidia-cuda-nvrtc-cu12 = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] +nvidia-cuda-runtime-cu12 = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, ] -nvidia-cuda-nvrtc-cu11 = [ - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:9f1562822ea264b7e34ed5930567e89242d266448e936b85bc97a3370feabb03"}, - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:f7d9610d9b7c331fa0da2d1b2858a4a8315e6d49765091d28711c8946e7425e7"}, - {file = "nvidia_cuda_nvrtc_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:f2effeb1309bdd1b3854fc9b17eaf997808f8b25968ce0c7070945c4265d64a3"}, +nvidia-cudnn-cu12 = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, ] -nvidia-cuda-runtime-cu11 = [ - {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl", hash = "sha256:cc768314ae58d2641f07eac350f40f99dcb35719c4faff4bc458a7cd2b119e31"}, - {file = "nvidia_cuda_runtime_cu11-11.7.99-py3-none-win_amd64.whl", hash = "sha256:bc77fa59a7679310df9d5c70ab13c4e34c64ae2124dd1efd7e5474b71be125c7"}, +nvidia-cufft-cu12 = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, ] -nvidia-cudnn-cu11 = [ - {file = "nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl", hash = "sha256:402f40adfc6f418f9dae9ab402e773cfed9beae52333f6d86ae3107a1b9527e7"}, - {file = "nvidia_cudnn_cu11-8.5.0.96-py3-none-manylinux1_x86_64.whl", hash = "sha256:71f8111eb830879ff2836db3cccf03bbd735df9b0d17cd93761732ac50a8a108"}, +nvidia-curand-cu12 = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] +nvidia-cusolver-cu12 = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] +nvidia-cusparse-cu12 = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] +nvidia-nccl-cu12 = [ + {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, +] +nvidia-nvjitlink-cu12 = [ + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, +] +nvidia-nvtx-cu12 = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] overrides = [ - {file = "overrides-7.6.0-py3-none-any.whl", hash = "sha256:c36e6635519ea9c5b043b65c36d4b886aee8bd45b7d4681d2a6df0898df4b654"}, - {file = "overrides-7.6.0.tar.gz", hash = "sha256:01e15bbbf15b766f0675c275baa1878bd1c7dc9bc7b9ee13e677cdba93dc1bd9"}, + {file = "overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49"}, + {file = "overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a"}, ] packaging = [ {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, @@ -3358,8 +3566,8 @@ python-json-logger = [ {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"}, ] pytz = [ - {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"}, - {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"}, + {file = "pytz-2023.4-py2.py3-none-any.whl", hash = "sha256:f90ef520d95e7c46951105338d918664ebfd6f1d995bd7d153127ce90efafa6a"}, + {file = "pytz-2023.4.tar.gz", hash = "sha256:31d4583c4ed539cd037956140d695e42c033a19e984bfce9964a3f7d59bc2b40"}, ] pywin32 = [ {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, @@ -3391,6 +3599,7 @@ pyyaml = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3398,8 +3607,15 @@ pyyaml = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3416,6 +3632,7 @@ pyyaml = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3423,6 +3640,7 @@ pyyaml = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3531,8 +3749,8 @@ qtpy = [ {file = "QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"}, ] referencing = [ - {file = "referencing-0.32.1-py3-none-any.whl", hash = "sha256:7e4dc12271d8e15612bfe35792f5ea1c40970dadf8624602e33db2758f7ee554"}, - {file = "referencing-0.32.1.tar.gz", hash = "sha256:3c57da0513e9563eb7e203ebe9bb3a1b509b042016433bd1e45a2853466c3dd3"}, + {file = "referencing-0.33.0-py3-none-any.whl", hash = "sha256:39240f2ecc770258f28b642dd47fd74bc8b02484de54e1882b74b35ebd779bd5"}, + {file = "referencing-0.33.0.tar.gz", hash = "sha256:c775fedf74bc0f9189c2a3be1c12fd03e8c23f4d371dce795df44e06c5b412f7"}, ] requests = [ {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, @@ -3694,6 +3912,10 @@ stack-data = [ {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, ] +sympy = [ + {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, + {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, +] terminado = [ {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"}, {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"}, @@ -3707,27 +3929,48 @@ tomli = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] torch = [ - {file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"}, - {file = "torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e"}, - {file = "torch-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:98124598cdff4c287dbf50f53fb455f0c1e3a88022b39648102957f3445e9b76"}, - {file = "torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:393a6273c832e047581063fb74335ff50b4c566217019cc6ace318cd79eb0566"}, - {file = "torch-1.13.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:0122806b111b949d21fa1a5f9764d1fd2fcc4a47cb7f8ff914204fd4fc752ed5"}, - {file = "torch-1.13.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:22128502fd8f5b25ac1cd849ecb64a418382ae81dd4ce2b5cebaa09ab15b0d9b"}, - {file = "torch-1.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:76024be052b659ac1304ab8475ab03ea0a12124c3e7626282c9c86798ac7bc11"}, - {file = "torch-1.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ea8dda84d796094eb8709df0fcd6b56dc20b58fdd6bc4e8d7109930dafc8e419"}, - {file = "torch-1.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:2ee7b81e9c457252bddd7d3da66fb1f619a5d12c24d7074de91c4ddafb832c93"}, - {file = "torch-1.13.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:0d9b8061048cfb78e675b9d2ea8503bfe30db43d583599ae8626b1263a0c1380"}, - {file = "torch-1.13.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:f402ca80b66e9fbd661ed4287d7553f7f3899d9ab54bf5c67faada1555abde28"}, - {file = "torch-1.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:727dbf00e2cf858052364c0e2a496684b9cb5aa01dc8a8bc8bbb7c54502bdcdd"}, - {file = "torch-1.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:df8434b0695e9ceb8cc70650afc1310d8ba949e6db2a0525ddd9c3b2b181e5fe"}, - {file = "torch-1.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:5e1e722a41f52a3f26f0c4fcec227e02c6c42f7c094f32e49d4beef7d1e213ea"}, - {file = "torch-1.13.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:33e67eea526e0bbb9151263e65417a9ef2d8fa53cbe628e87310060c9dcfa312"}, - {file = "torch-1.13.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:eeeb204d30fd40af6a2d80879b46a7efbe3cf43cdbeb8838dd4f3d126cc90b2b"}, - {file = "torch-1.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:50ff5e76d70074f6653d191fe4f6a42fdbe0cf942fbe2a3af0b75eaa414ac038"}, - {file = "torch-1.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:2c3581a3fd81eb1f0f22997cddffea569fea53bafa372b2c0471db373b26aafc"}, - {file = "torch-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0aa46f0ac95050c604bcf9ef71da9f1172e5037fdf2ebe051962d47b123848e7"}, - {file = "torch-1.13.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6930791efa8757cb6974af73d4996b6b50c592882a324b8fb0589c6a9ba2ddaf"}, - {file = "torch-1.13.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:e0df902a7c7dd6c795698532ee5970ce898672625635d885eade9976e5a04949"}, + {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, + {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, + {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, + {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, + {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, + {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, + {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, + {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, + {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, + {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, + {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, + {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, + {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, + {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, + {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, + {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, + {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, + {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, + {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, + {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, +] +torchaudio = [ + {file = "torchaudio-2.1.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:06f8c02814e6cdd78626bbf44ad2bb8afa5b39ab650c6af18328a32311461058"}, + {file = "torchaudio-2.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9d676673c1ce4dd11fca145e3a6cd9b4e5b897cffad0f617d2906f2d3fc8c3a9"}, + {file = "torchaudio-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:23c1b34e98664a0ac239efd4e1a0af407b3dd0a86a5869114bae582c3e5437d7"}, + {file = "torchaudio-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f82657fc4ec3b473bf6c752c0ee62d7f511af9ef37e5143f8339ec049504d767"}, + {file = "torchaudio-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:683eaa721e016ca1f27bb28fa89feae37a6f7b98ff1ceee0d5e5aedd19bd982c"}, + {file = "torchaudio-2.1.2-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:c1084eedf4ced1af9fdd18910690ff615f89baeb30b32030806543fbc6f3657e"}, + {file = "torchaudio-2.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:860acc32e6507063f2c13d81e26718199e215f34a2bcd6c9609a25e9bf21aa36"}, + {file = "torchaudio-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:50c651d60bde7a4e096bf376eddb9ea32da6e37c3827536d6e918798ad203dbf"}, + {file = "torchaudio-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:9a6bade8a2495a724f4ee6acb5e86828ff4083dc6c7c57c6386b54a0ea7afe71"}, + {file = "torchaudio-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:47f322708c282e0b1b7548cdbe4e12451c531061761885d7c7fe2e479a4a3861"}, + {file = "torchaudio-2.1.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:d0efd008c35dec962b80f5dce3468bd1b88301cf65152bbfa7f74c0005a17e89"}, + {file = "torchaudio-2.1.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:30ad97112412592518953f3cc2cd1b6ae153d6563dd5bd9eab6a972315fe9d9e"}, + {file = "torchaudio-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c33e05c2305bc4d659aaf77a385433e3f8ac07ae235d3b15d6ef4ff995258746"}, + {file = "torchaudio-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:ff7156b30eb05e9124286c30c80da84b93e227d009adb96eb19489600b459332"}, + {file = "torchaudio-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:dbaefae9ca0b208ce0157e0358cea8ab796c9e26a2c61c3d181246e4010b04d2"}, + {file = "torchaudio-2.1.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:ae50dcf34d5c6f73180cf694195ee31194b9d6090328575c30a5960bc716fa52"}, + {file = "torchaudio-2.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:14729cc9df52defa674fcf5ed4de0d6507038ef18012b96a2f56a77ed70676dd"}, + {file = "torchaudio-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:03b4cf02ee468b25280f9593cca95a32b517a88512a1e5f41129e24cd0c17e64"}, + {file = "torchaudio-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b3bbb5324e705ded77616e546591b249ae7588d35a3e8c2c4c1d986a5ea51ef4"}, + {file = "torchaudio-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:03d2a3c9a806486f2d9646381a564a922a880b6df8f18336b6f0e4a0d8356743"}, ] torchtyping = [ {file = "torchtyping-0.1.4-py3-none-any.whl", hash = "sha256:485fb6ef3965c39b0de15f00d6f49373e0a3a6993e9733942a63c5e207d35390"}, @@ -3754,6 +3997,16 @@ traitlets = [ {file = "traitlets-5.14.1-py3-none-any.whl", hash = "sha256:2e5a030e6eff91737c643231bfcf04a65b0132078dad75e4936700b213652e74"}, {file = "traitlets-5.14.1.tar.gz", hash = "sha256:8585105b371a04b8316a43d5ce29c098575c2e477850b62b848b964f1444527e"}, ] +triton = [ + {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, + {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, + {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, + {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, + {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, + {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, + {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, + {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, +] typeguard = [ {file = "typeguard-4.1.5-py3-none-any.whl", hash = "sha256:8923e55f8873caec136c892c3bed1f676eae7be57cdb94819281b3d3bc9c0953"}, {file = "typeguard-4.1.5.tar.gz", hash = "sha256:ea0a113bbc111bcffc90789ebb215625c963411f7096a7e9062d4e4630c155fd"}, @@ -3798,10 +4051,6 @@ websocket-client = [ {file = "websocket-client-1.7.0.tar.gz", hash = "sha256:10e511ea3a8c744631d3bd77e61eb17ed09304c413ad42cf6ddfa4c7787e8fe6"}, {file = "websocket_client-1.7.0-py3-none-any.whl", hash = "sha256:f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588"}, ] -wheel = [ - {file = "wheel-0.42.0-py3-none-any.whl", hash = "sha256:177f9c9b0d45c47873b619f5b650346d632cdc35fb5e4d25058e09c9e581433d"}, - {file = "wheel-0.42.0.tar.gz", hash = "sha256:c45be39f7882c9d34243236f2d63cbd58039e360f85d0913425fbd7ceea617a8"}, -] widgetsnbextension = [ {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"}, {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, diff --git a/pyproject.toml b/pyproject.toml index 71054c6..ca8eb9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,14 +4,15 @@ version = "0.3.0" description = "Tools for training torch models on gravitational wave data" readme = "README.md" authors = [ - "Alec Gunny ", "Ethan Marx " + "Alec Gunny ", "Ethan Marx ", "Will Benoit " ] [tool.poetry.dependencies] python = "^3.8" # torch deps -torch = "^1.10" +torch = "^2.0" +torchaudio = "^2.0" torchtyping = "^0.1" [tool.poetry.group.dev.dependencies] diff --git a/tests/transforms/test_spectrogram.py b/tests/transforms/test_spectrogram.py new file mode 100644 index 0000000..91c6753 --- /dev/null +++ b/tests/transforms/test_spectrogram.py @@ -0,0 +1,110 @@ +import pytest +import torch +from torchaudio.transforms import Spectrogram + +from ml4gw.transforms import MultiResolutionSpectrogram + + +@pytest.fixture(params=[2, 4, 5]) +def kernel_length(request): + return request.param + + +@pytest.fixture(params=[128, 256]) +def sample_rate(request): + return request.param + + +@pytest.fixture(params=[1, 10]) +def batch_size(request): + return request.param + + +@pytest.fixture(params=[1, 3]) +def num_channels(request): + return request.param + + +@pytest.fixture(params=[[64], [64, 128], [64, 128, 256]]) +def n_ffts(request): + return request.param + + +@pytest.fixture(params=[[50]]) +def win_lengths(request): + return request.param + + +@pytest.fixture(params=[[2]]) +def powers(request): + return request.param + + +def test_num_spectrogram_warning(): + with pytest.warns(RuntimeWarning): + MultiResolutionSpectrogram(8, 2048, n_fft=[32, 64, 128, 256]) + + +def test_normalization_error(): + with pytest.raises(ValueError): + MultiResolutionSpectrogram(8, 2048, normalized=[False]) + + +def test_arg_length_error(): + with pytest.raises(ValueError): + MultiResolutionSpectrogram(8, 2048, n_fft=[32, 64], power=[1, 1, 1]) + + +def test_kernel_size_error(): + with pytest.raises(ValueError): + spectrogram = MultiResolutionSpectrogram(8, 2048) + spectrogram(torch.ones([4, 3, 8 * 2048 + 1])) + + +def test_multi_resolution_spectrogram( + kernel_length, + sample_rate, + batch_size, + num_channels, + n_ffts, + win_lengths, + powers, +): + + # Creating a MRS without any spectrogram arguments should + # just create a single default torchaudio histogram with + # `normalized = True` + spectrogram = MultiResolutionSpectrogram(kernel_length, sample_rate) + + X = torch.ones([batch_size, num_channels, kernel_length * sample_rate]) + y = spectrogram(X) + expected_y = Spectrogram(normalized=True)(X) + + assert (y == expected_y).all() + + # Check that all the indexing we're doing is working by + # performing a more explicit version of the algorithm + spectrogram = MultiResolutionSpectrogram( + kernel_length, + sample_rate, + n_fft=n_ffts, + win_length=win_lengths, + power=powers, + ) + y = spectrogram(X) + kwargs = spectrogram.kwargs + ta_spectrograms = [Spectrogram(**k)(X[0, 0]) for k in kwargs] + t_dim = max([spec.shape[-1] for spec in ta_spectrograms]) + f_dim = max([spec.shape[-2] for spec in ta_spectrograms]) + expected_y = torch.zeros([f_dim, t_dim]) + + for i in range(t_dim): + for j in range(f_dim): + max_value = 0 + for spec in ta_spectrograms: + t_idx = int(i / t_dim * spec.shape[-1]) + f_idx = int(j / f_dim * spec.shape[-2]) + max_value = max(max_value, spec[f_idx, t_idx]) + expected_y[j, i] += max_value + + assert torch.allclose(y[0, 0], expected_y, rtol=1e-6)