diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6f086c3..11f43dd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,22 +4,21 @@ on: push jobs: test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python_version: ['3.8', '3.9', '3.10', '3.11'] + runs-on: [self-hosted, gpu] steps: - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python_version }} - - name: install dependencies - run: python3 -m pip install .[dev] + - name: test gpu is available + run: nvidia-smi + + - name: build image + run: make build - name: test - run: pytest --cov pitch_detectors --cov-fail-under=90 + run: make test + + - name: test-no-gpu + run: make test-no-gpu publish-to-pypi-and-github-release: if: "startsWith(github.ref, 'refs/tags/')" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fa65608 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.ipynb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..86b32fa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,64 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-added-large-files + - id: check-yaml + - id: check-json + - id: check-ast + - id: check-byte-order-marker + - id: check-builtin-literals + - id: check-case-conflict + - id: check-docstring-first + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: trailing-whitespace + - id: check-merge-conflict + - id: detect-private-key + - id: double-quote-string-fixer + - id: name-tests-test + - id: requirements-txt-fixer + + - repo: https://github.com/asottile/add-trailing-comma + rev: v2.3.0 + hooks: + - id: add-trailing-comma + + - repo: https://github.com/asottile/pyupgrade + rev: v3.1.0 + hooks: + - id: pyupgrade + + - repo: https://github.com/pre-commit/mirrors-autopep8 + rev: v1.7.0 + hooks: + - id: autopep8 + + - repo: 
https://github.com/PyCQA/autoflake + rev: v1.7.6 + hooks: + - id: autoflake + + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.254 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + + - repo: https://github.com/PyCQA/pylint + rev: v2.17.0 + hooks: + - id: pylint + additional_dependencies: ["pylint-per-file-ignores"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.982 + hooks: + - id: mypy + additional_dependencies: [types-redis] + + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.299 + hooks: + - id: pyright diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..3d842d5 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM alpine/curl as downloader +RUN curl -L https://tfhub.dev/google/spice/2?tf-hub-format=compressed --output spice_2.tar.gz && \ + mkdir /spice_model && \ + tar xvf spice_2.tar.gz --directory /spice_model && \ + rm spice_2.tar.gz + +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 + +# https://github.com/NVIDIA/nvidia-docker/wiki/Usage +# https://github.com/NVIDIA/nvidia-docker/issues/531 +ENV NVIDIA_DRIVER_CAPABILITIES compute,video,utility + +COPY --from=downloader /spice_model /spice_model + +RUN apt-get update && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common && \ + add-apt-repository -y ppa:deadsnakes/ppa && \ + apt-get install -y python3.10-venv libsndfile-dev libasound-dev portaudio19-dev + +# https://pythonspeed.com/articles/activate-virtualenv-dockerfile/ +ENV VIRTUAL_ENV=/venv +RUN python3.10 -m venv $VIRTUAL_ENV +ENV PATH="$VIRTUAL_ENV/bin:$PATH" + +WORKDIR /app +COPY pyproject.toml . +COPY README.md . 
+COPY pitch_detectors /app/pitch_detectors +COPY tests /app/tests +COPY data /app/data + +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade pip setuptools wheel && \ + pip install .[dev] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a247d51 --- /dev/null +++ b/Makefile @@ -0,0 +1,34 @@ +IMAGE = tandav/pitch-detectors:11.8.0-cudnn8-devel-ubuntu22.04 + +.PHONY: build +build: + DOCKER_BUILDKIT=1 docker build --progress=plain -t $(IMAGE) . + +.PHONY: push +push: + docker push $(IMAGE) + docker push tandav/pitch-detectors:latest + +.PHONY: test +test: build + docker run --rm -t --gpus all \ + -e PITCH_DETECTORS_GPU=true \ + $(IMAGE) \ + pytest -v --cov pitch_detectors + +.PHONY: test-no-gpu +test-no-gpu: build + docker run --rm -t \ + -e PITCH_DETECTORS_GPU=false \ + $(IMAGE) \ + pytest -v --cov pitch_detectors + +.PHONY: evaluation +evaluation: build + eval "$$(cat .env)"; \ + docker run --rm -t --gpus all \ + -e PITCH_DETECTORS_GPU=true \ + -e REDIS_URL=$$REDIS_URL \ + -v /home/tandav/Downloads/MIR-1K:/app/MIR-1K \ + $(IMAGE) \ + python pitch_detectors/evaluation.py diff --git a/README.md b/README.md index 12163f5..0aefec2 100644 --- a/README.md +++ b/README.md @@ -2,26 +2,48 @@ collection of pitch detection algorithms with unified interface ## list of algorithms -1. PraatAC [cpu] -1. PraatCC [cpu] -1. PraatSHS [cpu] -1. Pyin [cpu] -1. Reaper [cpu] -1. Yaapt [cpu] -1. Rapt [cpu] -1. World [cpu] -1. TorchYin [cpu] -1. Crepe [cpu, gpu] -1. TorchCrepe [cpu, gpu] -1. 
Swipe [cpu, gpu] + +| algorithm | cpu | gpu | accuracy [1] | +|------------------------------------------------------------------------------------------------------------|-----|-----|--------------| +| [PraatAC](https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_ac) | ✓ | | 0.880 | +| [PraatCC](https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_cc) | ✓ | | 0.893 | +| [PraatSHS](https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_shs) | ✓ | | 0.618 | +| [Pyin](https://librosa.org/doc/latest/generated/librosa.pyin.html) | ✓ | | 0.886 | +| [Reaper](https://github.com/r9y9/pyreaper) | ✓ | | 0.826 | +| [Yaapt](http://bjbschmitt.github.io/AMFM_decompy/pYAAPT.html#amfm_decompy.pYAAPT.yaapt) | ✓ | | 0.759 | +| [World](https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder) | ✓ | | 0.873 | +| [TorchYin](https://github.com/brentspell/torch-yin) | ✓ | | 0.886 | +| [Rapt](https://pysptk.readthedocs.io/en/stable/generated/pysptk.sptk.rapt.html) | ✓ | | 0.859 | +| [Swipe](https://pysptk.readthedocs.io/en/stable/generated/pysptk.sptk.swipe.html) | ✓ | | 0.871 | +| [Crepe](https://github.com/marl/crepe) | ✓ | ✓ | 0.802 | +| [TorchCrepe](https://github.com/maxrmorrison/torchcrepe) | ✓ | ✓ | 0.817 | +| [Spice](https://ai.googleblog.com/2019/11/spice-self-supervised-pitch-estimation.html) | ✓ | ✓ | 0.908 | + -## additional features -- robust (vote-based + median) averaging of pitch -- json import/export +- [1] accuracy is the mean [raw pitch accuracy](http://craffel.github.io/mir_eval/#mir_eval.melody.raw_pitch_accuracy) on 1000 samples of the [MIR-1K](https://www.kaggle.com/datasets/datongmuyuyi/mir1k) dataset ## install +all algorithms are tested on python3.10, which is the recommended python version to use ```bash pip install pitch-detectors ``` ## usage + +```python +from scipy.io import wavfile +from pitch_detectors import algorithms +import matplotlib.pyplot as plt + 
+fs, a = wavfile.read('data/b1a5da49d564a7341e7e1327aa3f229a.wav') +pitch = algorithms.Crepe(a, fs) +plt.plot(pitch.t, pitch.f0) +plt.show() +``` + +![example plot of detected pitch (f0) over time](data/b1a5da49d564a7341e7e1327aa3f229a.png) + + +## additional features +- [ ] robust (vote-based + median) averaging of pitch +- [ ] json import/export diff --git a/data/b1a5da49d564a7341e7e1327aa3f229a.png b/data/b1a5da49d564a7341e7e1327aa3f229a.png new file mode 100644 index 0000000..632d896 Binary files /dev/null and b/data/b1a5da49d564a7341e7e1327aa3f229a.png differ diff --git a/tests/data/b1a5da49d564a7341e7e1327aa3f229a.wav b/data/b1a5da49d564a7341e7e1327aa3f229a.wav similarity index 100% rename from tests/data/b1a5da49d564a7341e7e1327aa3f229a.wav rename to data/b1a5da49d564a7341e7e1327aa3f229a.wav diff --git a/pitch_detectors/algorithms.py b/pitch_detectors/algorithms.py index d36184f..d010b66 100644 --- a/pitch_detectors/algorithms.py +++ b/pitch_detectors/algorithms.py @@ -1,6 +1,7 @@ +import os + import numpy as np -from pitch_detectors import config from pitch_detectors import util @@ -13,14 +14,40 @@ def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 6 self.hz_min = hz_min self.hz_max = hz_max self.seconds = len(a) / fs - self.f0 = None - - def dict(self): + self.f0: np.ndarray + self.t: np.ndarray + if ( + os.environ.get('PITCH_DETECTORS_GPU') == 'true' and + self.use_gpu and + not self.gpu_available() + ): + raise ConnectionError(f'gpu must be available for {self.name()} algorithm') + + def dict(self) -> dict[str, list[float | None]]: return {'f0': util.nan_to_none(self.f0.tolist()), 't': self.t.tolist()} @classmethod - def name(cls): - return cls.__class__.__name__ + def name(cls) -> str: + return cls.__name__ + + def gpu_available(self) -> bool: + return False + + +class TensorflowGPU: + use_gpu = True + + def gpu_available(self) -> bool: + import tensorflow as tf + return bool(tf.config.experimental.list_physical_devices('GPU')) + + +class TorchGPU: + use_gpu = 
True + + def gpu_available(self) -> bool: + import torch + return torch.cuda.is_available() # type: ignore class PraatAC(PitchDetector): @@ -69,17 +96,13 @@ def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 6 self.t = np.linspace(0, self.seconds, f0.shape[0]) -class Crepe(PitchDetector): - use_gpu = True - +class Crepe(TensorflowGPU, PitchDetector): def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 600, confidence_threshold: float = 0.8): import crepe import tensorflow as tf super().__init__(a, fs, hz_min, hz_max) gpus = tf.config.experimental.list_physical_devices('GPU') - if not gpus: - raise RuntimeError('Crepe requires a GPU') for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) @@ -87,22 +110,19 @@ def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 6 self.f0[self.confidence < confidence_threshold] = np.nan -class TorchCrepe(PitchDetector): - use_gpu = True - +class TorchCrepe(TorchGPU, PitchDetector): def __init__( self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 600, confidence_threshold: float = 0.8, - batch_size=2048, - device=None, + batch_size: int = 2048, + device: str | None = None, ): import torch import torchcrepe if device is None: - torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + torch.device(device) super().__init__(a, fs, hz_min, hz_max) - if not torch.cuda.is_available(): - raise RuntimeError('TorchCrepe requires a GPU') f0, confidence = torchcrepe.predict( torch.from_numpy(a[np.newaxis, ...]), @@ -128,7 +148,7 @@ def __init__( class Yaapt(PitchDetector): def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 600): import amfm_decompy.basic_tools as basic - import amfm_decompy.pYAAPT as pYAAPT + from amfm_decompy import pYAAPT super().__init__(a, fs, hz_min, hz_max) self.signal = basic.SignalObj(data=self.a, fs=self.fs) f0 = 
pYAAPT.yaapt(self.signal, f0_min=self.hz_min, f0_max=self.hz_max, frame_length=15) @@ -160,10 +180,10 @@ def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 6 class Reaper(PitchDetector): def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 600): - import dsplib.scale import pyreaper + from dsplib.scale import minmax_scaler int16_info = np.iinfo(np.int16) - a = dsplib.scale.minmax_scaler(a, np.min(a), np.max(a), int16_info.min, int16_info.max).round().astype(np.int16) + a = minmax_scaler(a, np.min(a), np.max(a), int16_info.min, int16_info.max).round().astype(np.int16) super().__init__(a, fs, hz_min, hz_max) pm_times, pm, f0_times, f0, corr = pyreaper.reaper(self.a, fs=self.fs, minf0=self.hz_min, maxf0=self.hz_max, frame_period=0.01) f0[f0 == -1] = np.nan @@ -171,41 +191,42 @@ def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 6 self.t = f0_times -class Spice(PitchDetector): - """https://ai.googleblog.com/2019/11/spice-self-supervised-pitch-estimation.html""" - use_gpu = True - +class Spice(TensorflowGPU, PitchDetector): def __init__( - self, a: np.ndarray, fs: int, - confidence_threshold=0.8, + self, + a: np.ndarray, + fs: int, + confidence_threshold: float = 0.8, expected_sample_rate: int = 16000, + spice_model_path: str = '/spice_model', ): import resampy import tensorflow as tf import tensorflow_hub as hub a = resampy.resample(a, fs, expected_sample_rate) super().__init__(a, fs) - model = hub.load(config.spice_model_path) + model = hub.load(spice_model_path) model_output = model.signatures['serving_default'](tf.constant(a, tf.float32)) confidence = 1.0 - model_output['uncertainty'] - f0 = self.output2hz(model_output['pitch'].numpy()) - f0[confidence < confidence_threshold] = np.nan + self.f0 = self.output2hz(model_output['pitch'].numpy()) + self.f0[confidence < confidence_threshold] = np.nan + self.t = np.linspace(0, self.seconds, self.f0.shape[0]) def output2hz( self, pitch_output: 
np.ndarray, - PT_OFFSET: float = 25.58, - PT_SLOPE: float = 63.07, - FMIN: float = 10.0, - BINS_PER_OCTAVE: float = 12.0, + pt_offset: float = 25.58, + pt_slope: float = 63.07, + fmin: float = 10.0, + bins_per_octave: float = 12.0, ) -> np.ndarray: """convert pitch from the model output [0.0, 1.0] range to absolute values in Hz.""" - cqt_bin = pitch_output * PT_SLOPE + PT_OFFSET - return FMIN * 2.0 ** (1.0 * cqt_bin / BINS_PER_OCTAVE) + cqt_bin = pitch_output * pt_slope + pt_offset + return fmin * 2.0 ** (1.0 * cqt_bin / bins_per_octave) class World(PitchDetector): - def __init__(self, a: np.ndarray, fs): + def __init__(self, a: np.ndarray, fs: int): import pyworld super().__init__(a, fs) f0, sp, ap = pyworld.wav2world(a.astype(float), fs) @@ -218,9 +239,9 @@ class TorchYin(PitchDetector): def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 600): import torch import torchyin - a = torch.from_numpy(a) super().__init__(a, fs, hz_min, hz_max) - f0 = torchyin.estimate(self.a, sample_rate=self.fs, pitch_min=self.hz_min, pitch_max=self.hz_max) + _a = torch.from_numpy(a) + f0 = torchyin.estimate(_a, sample_rate=self.fs, pitch_min=self.hz_min, pitch_max=self.hz_max) f0[f0 == 0] = np.nan self.f0 = f0[:-1] self.t = np.linspace(0, self.seconds, f0.shape[0])[1:] @@ -239,6 +260,7 @@ def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 6 Rapt, World, TorchYin, + Spice, ) cpu_algorithms = ( @@ -256,7 +278,7 @@ def __init__(self, a: np.ndarray, fs: int, hz_min: float = 75, hz_max: float = 6 gpu_algorithms = ( 'Crepe', 'TorchCrepe', - 'Swipe', + 'Spice', ) algorithms = cpu_algorithms + gpu_algorithms diff --git a/pitch_detectors/config.py b/pitch_detectors/config.py deleted file mode 100644 index bce54d0..0000000 --- a/pitch_detectors/config.py +++ /dev/null @@ -1 +0,0 @@ -spice_model_path = 'data/spice_model/' diff --git a/pitch_detectors/evaluation.py b/pitch_detectors/evaluation.py new file mode 100644 index 0000000..cc81bca 
--- /dev/null +++ b/pitch_detectors/evaluation.py @@ -0,0 +1,114 @@ +import argparse +import os +import time +from pathlib import Path + +import mir_eval +import numpy as np +import tqdm +from dsplib.scale import minmax_scaler +from musiclib.pitch import Pitch +from redis import Redis +from scipy.io import wavfile + +from pitch_detectors import algorithms + +MIR_1K_DIR = Path('MIR-1K') +WAV_DIR = MIR_1K_DIR / 'Wavfile' + + +def load_f0_true(wav_path: Path, seconds: float) -> tuple[np.ndarray, np.ndarray]: + p = Pitch() + pitch_label_dir = wav_path.parent.parent / 'PitchLabel' + f0_path = (pitch_label_dir / wav_path.stem).with_suffix('.pv') + f0 = [] + with open(f0_path) as f: + for _line in f: + line = _line.strip() + if line == '0': + f0.append(float('nan')) + else: + f0.append(p.note_i_to_hz(float(line))) + f0 = np.array(f0) + # t = np.arange(0.02, seconds - 0.02, 0.02) + # assert t.shape == f0.shape + t = np.linspace(0.02, seconds, len(f0)) + return t, f0 + + +def resample_f0( + pitch: algorithms.PitchDetector, + t_resampled: np.ndarray, +) -> np.ndarray: + f0_resampled = np.full_like(t_resampled, fill_value=np.nan) + notna_slices = np.ma.clump_unmasked(np.ma.masked_invalid(pitch.f0)) + for slice_ in notna_slices: + t_slice = pitch.t[slice_] + f0_slice = pitch.f0[slice_] + t_start, t_stop = t_slice[0], t_slice[-1] + mask = (t_start < t_resampled) & (t_resampled < t_stop) + t_interp = t_resampled[mask] + f0_interp = np.interp(t_interp, t_slice, f0_slice) + f0_resampled[mask] = f0_interp + return f0_resampled + + +def raw_pitch_accuracy( + ref_f0: np.ndarray, + est_f0: np.ndarray, + cent_tolerance: float = 50, +) -> float: + ref_voicing = np.isfinite(ref_f0) + est_voicing = np.isfinite(est_f0) + ref_cent = mir_eval.melody.hz2cents(ref_f0) + est_cent = mir_eval.melody.hz2cents(est_f0) + score: float = mir_eval.melody.raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance) + return score + + +def evaluate_one( + redis: Redis[str], + 
algorithm: type[algorithms.PitchDetector], + wav_path: Path, +) -> str: + key = f'pitch_detectors:evaluation:{algorithm.name()}:{wav_path.stem}' + if redis.exists(key): + return key + fs, a = wavfile.read(wav_path) + seconds = len(a) / fs + a = a[:, 1].astype(np.float32) + rescale = 100000 + a = minmax_scaler(a, a.min(), a.max(), -rescale, rescale).astype(np.float32) + t_true, f0_true = load_f0_true(wav_path, seconds) + pitch = algorithm(a, fs) + f0 = resample_f0(pitch, t_resampled=t_true) + score = raw_pitch_accuracy(f0_true, f0) + redis.hset( + key, mapping={ + 'raw_pitch_accuracy': score, + 'timestamp': int(time.time() * 1000), + }, + ) + return key + + +def evaluate_all(redis: Redis[str]) -> None: + t = tqdm.tqdm(sorted(WAV_DIR.glob('*.wav'))) + for wav_path in t: + for algorithm in tqdm.tqdm(algorithms.ALGORITHMS, leave=False): + key = evaluate_one(redis, algorithm, wav_path) + t.set_description(key) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--algorithm', type=str) + parser.add_argument('--file', type=str) + args = parser.parse_args() + if (args.algorithm is None) ^ (args.file is None): + raise ValueError('you must specify both algorithm and file or neither') + redis = Redis.from_url(os.environ['REDIS_URL'], decode_responses=True) + if args.algorithm is not None and args.file is not None: + evaluate_one(redis, algorithm=getattr(algorithms, args.algorithm), wav_path=WAV_DIR / args.file) + raise SystemExit(0) + evaluate_all(redis) diff --git a/pitch_detectors/util.py b/pitch_detectors/util.py index 4fd2276..c19979c 100644 --- a/pitch_detectors/util.py +++ b/pitch_detectors/util.py @@ -1,4 +1,9 @@ import math +from pathlib import Path + +import numpy as np +from dsplib.scale import minmax_scaler +from scipy.io import wavfile def nan_to_none(x: list[float]) -> list[float | None]: @@ -7,3 +12,9 @@ def nan_to_none(x: list[float]) -> list[float | None]: def none_to_nan(x: list[float | None]) -> list[float]: return 
[float('nan') if v is None else v for v in x] + + +def load_wav(path: Path | str, rescale: float = 100000) -> tuple[int, np.ndarray]: + fs, a = wavfile.read(path) + a = minmax_scaler(a, a.min(), a.max(), -rescale, rescale).astype(np.float32) + return fs, a diff --git a/pyproject.toml b/pyproject.toml index 6519e96..4821013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,11 @@ authors = [ ] description = "collection of pitch detection algorithms with unified interface" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.8,<3.11" dependencies = [ "AMFM-decompy", "crepe", - "dsplib==0.7.0", + "dsplib==0.7.2", "librosa", "numpy", "praat-parselmouth==0.4.1", @@ -32,6 +32,14 @@ dev = [ "bumpver", "pre-commit", "pytest", + "pytest-order", + "pytest-cov", + "mir_eval", + "tqdm", + "resampy", + "redis", + "musiclib", + "python-dotenv", ] [project.urls] @@ -46,6 +54,9 @@ issues = "https://github.com/tandav/pitch-detectors/issues" requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.setuptools] +packages = ["pitch_detectors"] + # ============================================================================== [tool.bumpver] @@ -116,14 +127,7 @@ extend-select = [ ] ignore = [ "E501", # line too long - "E731", - "E701", - "E702", - "F403", # star imports - "F405", # star imports - "B008", "PLR0913", - "TCH003", ] [tool.ruff.per-file-ignores] @@ -135,17 +139,44 @@ force-single-line = true # ============================================================================== +[tool.pylint.MASTER] +load-plugins=[ + "pylint_per_file_ignores", +] + +[tool.pylint.BASIC] +good-names = [ + "a", + "fs", + "f0", + "t", + "x", + "pm", + "sp", + "ap", + "f", + "p", +] + [tool.pylint.messages-control] disable = [ - "C0321","C3001","C0116","C0301","C0103","C0115","C0114", - "W1514", - "W0401", # wildcard import - "W0614", - "W1113", - "R0903", - "E0401", + "missing-function-docstring", + "missing-class-docstring", + "missing-module-docstring", + 
"line-too-long", + "import-outside-toplevel", + "unused-variable", + "too-many-arguments", + "import-error", + "too-few-public-methods", + "unspecified-encoding", + "redefined-outer-name", ] +[tool.pylint-per-file-ignores] +"/tests/" = "redefined-outer-name" + + # ============================================================================== [tool.autopep8] @@ -154,3 +185,9 @@ recursive = true aggressive = 3 # ============================================================================== + +[tool.pyright] +venvPath = "/home/tandav/.virtualenvs" +venv = "pitch-detectors" + +# ============================================================================== diff --git a/tests/algorithms_test.py b/tests/algorithms_test.py index d0222ab..f45cd6f 100644 --- a/tests/algorithms_test.py +++ b/tests/algorithms_test.py @@ -1,20 +1,28 @@ +import dataclasses +from pathlib import Path + import numpy as np import pytest + +from pitch_detectors import util from pitch_detectors.algorithms import ALGORITHMS -from scipy.io import wavfile -from pathlib import Path -from dsplib.scale import minmax_scaler + + +@dataclasses.dataclass +class Record: + a: np.ndarray + fs: int @pytest.fixture -def a_fs(rescale: float = 100000): - """audio and fs""" - fs, a = wavfile.read(Path(__file__).parent / 'data' / 'b1a5da49d564a7341e7e1327aa3f229a.wav') - a = minmax_scaler(a, a.min(), a.max(), -rescale, rescale).astype(np.float32) - assert a.dtype == np.float32 - yield a, fs +def record(): + fs, a = util.load_wav(Path(__file__).parent.parent / 'data' / 'b1a5da49d564a7341e7e1327aa3f229a.wav') + return Record(a, fs) + +@pytest.mark.order(3) +@pytest.mark.filterwarnings('ignore:pkg_resources is deprecated as an API') +@pytest.mark.filterwarnings('ignore:Deprecated call to `pkg_resources.declare_namespace') @pytest.mark.parametrize('algorithm', ALGORITHMS) -def test_detection(algorithm, a_fs): - a, fs = a_fs - p = algorithm(a, fs) +def test_detection(algorithm, record): + algorithm(record.a, record.fs) diff 
--git a/tests/gpu_test.py b/tests/gpu_test.py new file mode 100644 index 0000000..71814fe --- /dev/null +++ b/tests/gpu_test.py @@ -0,0 +1,24 @@ +import os +import subprocess + +import pytest +import tensorflow as tf +import torch + + +@pytest.mark.order(0) +@pytest.mark.skipif(os.environ.get('PITCH_DETECTORS_GPU') == 'false', reason='gpu is not used') +def test_nvidia_smi(): + subprocess.check_call('nvidia-smi') + + +@pytest.mark.order(1) +@pytest.mark.skipif(os.environ.get('PITCH_DETECTORS_GPU') == 'false', reason='gpu is not used') +def test_tensorflow(): + assert tf.config.experimental.list_physical_devices('GPU') + + +@pytest.mark.order(2) +@pytest.mark.skipif(os.environ.get('PITCH_DETECTORS_GPU') == 'false', reason='gpu is not used') +def test_pytorch(): + assert torch.cuda.is_available()