From 3dafe6d3d2b430a8c6c25310dbf58e94d6295359 Mon Sep 17 00:00:00 2001 From: Doug Turnbull Date: Tue, 26 Dec 2023 10:12:29 -0500 Subject: [PATCH] Fix mypy issues --- .flake8 | 13 +++++++++++++ Makefile | 2 +- mypy.ini | 7 +++++++ requirements.txt | 22 +++++++++++++++++++--- searcharray/phrase/middle_out.py | 3 ++- searcharray/postings.py | 3 ++- test/test_utils.py | 6 ++++-- 7 files changed, 48 insertions(+), 8 deletions(-) create mode 100644 .flake8 create mode 100644 mypy.ini diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..538f1ba --- /dev/null +++ b/.flake8 @@ -0,0 +1,13 @@ +[flake8] +ignore = + # https://github.com/psf/black#slices + E203, + # https://github.com/psf/black#line-length + E501, + # https://github.com/psf/black#line-breaks--binary-operators + W503, + # we don't require docstrings by default + D100, D101, D102, D103, D104, D105, D106, D107, + # this check is overly picky about RNG usage when it might not matter + # https://github.snooguts.net/reddit/docker-reddit-lint.py/pull/2 + DUO102, diff --git a/Makefile b/Makefile index 1257efc..115a17c 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ test: deps lint: deps @echo "Linting..." python -m flake8 --max-line-length=120 --ignore=E203,W503,E501,E722,E731,W605 --exclude=venv,build,dist,docs,*.egg-info,*.egg,*.pyc,*.pyo,*.git,__pycache__,.pytest_cache,.benchmarks - mypy searcharray tests + mypy searcharray test benchmark_dry_run: deps diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..1eaf7e6 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] + +[mypy-sortednp] +ignore_missing_imports = True + +[mypy-pandas.tests.*] +ignore_missing_imports = True diff --git a/requirements.txt b/requirements.txt index 99cc41a..382075c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,14 @@ asttokens==2.4.1 build==1.0.3 certifi==2023.7.22 charset-normalizer==3.3.2 +contourpy==1.2.0 coverage==7.3.2 +cycler==0.12.1 decorator==5.1.1 docutils==0.20.1 executing==2.0.1 +flake8==6.1.0 +fonttools==4.46.0 idna==3.4 importlib-metadata==6.8.0 iniconfig==2.0.0 @@ -14,24 +18,34 @@ ipython==8.17.2 jaraco.classes==3.3.0 jedi==0.19.1 keyring==24.3.0 +kiwisolver==1.4.5 markdown-it-py==3.0.0 +matplotlib==3.8.2 matplotlib-inline==0.1.6 +mccabe==0.7.0 mdurl==0.1.2 more-itertools==10.1.0 +mypy==1.8.0 +mypy-extensions==1.0.0 nh3==0.2.14 numpy==1.26.1 packaging==23.2 pandas==2.1.2 +pandas-stubs==2.1.4.231218 parso==0.8.3 pexpect==4.8.0 +Pillow==10.1.0 pkginfo==1.9.6 pluggy==1.3.0 prompt-toolkit==3.0.39 ptyprocess==0.7.0 pure-eval==0.2.2 py-cpuinfo==9.0.0 +pycodestyle==2.11.1 +pyflakes==3.1.0 pygal==3.0.4 Pygments==2.16.1 +pyparsing==3.1.1 pyproject_hooks==1.0.0 pyroaring==0.4.4 pytest==7.4.3 @@ -44,16 +58,18 @@ requests==2.31.0 requests-toolbelt==1.0.0 rfc3986==2.0.0 rich==13.6.0 -scipy==1.11.3 +scipy==1.11.4 searcharray==0.0.1 six==1.16.0 +snakeviz==2.2.0 sortednp==0.4.1 stack-data==0.6.3 +tornado==6.4 traitlets==5.13.0 twine==4.0.2 +types-pytz==2023.3.1.1 +typing_extensions==4.9.0 tzdata==2023.3 urllib3==2.1.0 wcwidth==0.2.9 zipp==3.17.0 -flake8==6.1.0 -mypy==1.8.0 diff --git a/searcharray/phrase/middle_out.py b/searcharray/phrase/middle_out.py index fb48812..50db4d3 100644 --- a/searcharray/phrase/middle_out.py +++ b/searcharray/phrase/middle_out.py @@ -8,7 +8,7 @@ import numpy as np import sortednp as snp from copy import deepcopy -from typing import List, Tuple, Dict, Union +from typing import List, Tuple, Dict, Union, cast from searcharray.utils.roaringish import RoaringishEncoder, convert_keys import numbers import logging @@ -293,6 +293,7 @@ def positions(self, term_id: int, key) -> Union[List[np.ndarray], np.ndarray]: return [np.array([], dtype=np.uint32)] if len(decoded) != len(doc_ids): # Fill non matches + decoded = cast(List[Tuple[np.uint64, np.ndarray]], decoded) as_dict: Dict[np.uint64, np.ndarray] = dict(decoded) decs = [] for doc_id in doc_ids: diff --git a/searcharray/postings.py b/searcharray/postings.py index b369c31..9defd6a 100644 --- a/searcharray/postings.py +++ b/searcharray/postings.py @@ -603,7 +603,8 @@ def doc_freq(self, token: str) -> int: raise TypeError("Expected a string") # Count number of rows where the term appears term_freq = self.term_freq(token) - return np.sum(term_freq > 0) + gt_0 = term_freq > 0 + return np.sum(gt_0).astype(int) def doc_lengths(self) -> np.ndarray: return self.doc_lens diff --git a/test/test_utils.py b/test/test_utils.py index 983b111..2aba641 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,5 @@ import pytest -from typing import Dict, Any +from typing import Dict, Any, cast, Sequence, Type, Union import cProfile import sys @@ -7,7 +7,7 @@ def w_scenarios(scenarios: Dict[str, Dict[str, Any]]): """Decorate for parametrizing tests that names the scenarios and params.""" return pytest.mark.parametrize( - [key for key in scenarios.values()][0].keys(), + cast(Sequence[str], [key for key in scenarios.values()][0].keys()), [tuple(scenario.values()) for scenario in scenarios.values()], ids=list(scenarios.keys()) ) @@ -42,6 +42,8 @@ def run(self, func, *args, **kwargs): return rval +Profiler: Union[Type[JustBenchmarkProfiler], Type[CProfileProfiler]] + if '--benchmark-disable' in sys.argv: Profiler = CProfileProfiler else: