diff --git a/tests/test_zkpytb_logging.py b/tests/test_zkpytb_logging.py index b47f0a8..558ced1 100644 --- a/tests/test_zkpytb_logging.py +++ b/tests/test_zkpytb_logging.py @@ -121,9 +121,9 @@ def test_setup_simple_console_and_file_logger_4(tmpdir): """ logger_name = 'logtest4' with LogCapture('zkpytb.logging') as lc2: - logger = setup_simple_console_and_file_logger(logger_name, logdir=123) + logger = setup_simple_console_and_file_logger(logger_name, logdir=123) # type: ignore err_msg = ('Invalid type for argument "logdir". ' - 'Expected "str", "bytes" or "pathlib.Path". ' + 'Expected "str" or "pathlib.Path". ' 'Received type: {}'.format(type(123))) lc2.check(('zkpytb.logging', 'ERROR', err_msg)) assert len(logger.handlers) == 1 diff --git a/tests/test_zkpytb_pandas.py b/tests/test_zkpytb_pandas.py index 498608e..0be1239 100644 --- a/tests/test_zkpytb_pandas.py +++ b/tests/test_zkpytb_pandas.py @@ -49,7 +49,7 @@ def generate_test_df(df_len, rand_seed): 'd': np.arange(df_len, dtype=float), 'e': np.random.RandomState(rand_seed).rand(df_len), 'f': np.random.RandomState(rand_seed).randn(df_len), - 'g': np.random.RandomState(rand_seed).random_integers(0, 100, df_len), + 'g': np.random.RandomState(rand_seed).randint(0, 100 + 1, df_len), 'n': np.zeros(df_len) + np.nan, }, index=range(df_len)) @@ -116,11 +116,13 @@ def test_move_col_to_beginning_of_df(df1): def test_compare_df_cols_mode1(df1, df2): res = compare_df_cols([df1, df2], ['e', 'f', 'g'], mode=1) + assert res is not None assert list(res.columns) == ['e_1', 'e_2', 'f_1', 'f_2', 'g_1', 'g_2'] def test_compare_df_cols_mode2(df1, df2): res = compare_df_cols([df1, df2], ['e', 'f', 'g'], mode=2) + assert res is not None assert list(res.columns) == ['e_1', 'f_1', 'g_1', 'e_2', 'f_2', 'g_2'] diff --git a/tests/test_zkpytb_utils.py b/tests/test_zkpytb_utils.py index c5eda10..f50dec8 100644 --- a/tests/test_zkpytb_utils.py +++ b/tests/test_zkpytb_utils.py @@ -6,6 +6,7 @@ import os import pytest import subprocess +from 
pathlib import Path from zkpytb.utils import ( @@ -24,29 +25,36 @@ def a_hash_method(request): def test_hashstring_hashmethods(a_hash_method): hash_res1 = hashstring(b'', hash_method=a_hash_method) hash_res2 = hashstring(b'test', hash_method=a_hash_method) - assert isinstance(hash_res1, str) - assert len(hash_res1) > 0 + hash_res3 = hashstring(b'test', hash_method=a_hash_method) + assert all(isinstance(res, str) for res in [hash_res1, hash_res2, hash_res3]) + assert all(len(res) > 0 for res in [hash_res1, hash_res2, hash_res3]) assert hash_res1 != hash_res2 + assert hash_res2 == hash_res3 -def test_hashfile_hashmethods(a_hash_method, tmpdir): - f1 = tmpdir.join("hashfile1.txt") - f2 = tmpdir.join("hashfile2.txt") - f1.write("content1") - f2.write("content2") +def test_hashfile_hashmethods(a_hash_method, tmp_path: Path): + f1 = tmp_path / "hashfile1.txt" + f2 = tmp_path / "hashfile2.txt" + f3 = tmp_path / "hashfile3.txt" + f1.write_text("content1") + f2.write_text("content2") + f3.write_text("content2") hashfile_res1 = hashfile(f1, hash_method=a_hash_method) hashfile_res2 = hashfile(f2, hash_method=a_hash_method) - assert isinstance(hashfile_res1, str) - assert len(hashfile_res1) > 0 + hashfile_res3 = hashfile(f3, hash_method=a_hash_method) + assert all(isinstance(res, str) for res in [hashfile_res1, hashfile_res2, hashfile_res3]) + assert all(len(res) > 0 for res in [hashfile_res1, hashfile_res2, hashfile_res3]) assert hashfile_res1 != hashfile_res2 + assert hashfile_res2 == hashfile_res3 @pytest.mark.xfail(reason="Tox temp dir is located under the main git repository of the project...") -def test_get_git_hash_nogit(tmpdir): - curdir = os.getcwd() +def test_get_git_hash_nogit(tmp_path: Path): + curdir = Path.cwd() try: - nogit_dir = tmpdir.mkdir("nogit_dir") - os.chdir(str(nogit_dir)) + nogit_dir: Path = tmp_path / "nogit_dir" + nogit_dir.mkdir() + os.chdir(nogit_dir) git_hash = get_git_hash() finally: os.chdir(curdir) @@ -54,11 +62,12 @@ def 
test_get_git_hash_nogit(tmpdir): assert git_hash == '' -def test_get_git_hash_emptygit(tmpdir): - curdir = os.getcwd() +def test_get_git_hash_emptygit(tmp_path: Path): + curdir = Path.cwd() try: - emptygit_dir = tmpdir.mkdir("emptygit_dir") - os.chdir(str(emptygit_dir)) + emptygit_dir: Path = tmp_path / "emptygit_dir" + emptygit_dir.mkdir() + os.chdir(emptygit_dir) subprocess.check_call(['git', 'init']) git_hash = get_git_hash() finally: @@ -67,11 +76,12 @@ def test_get_git_hash_emptygit(tmpdir): assert git_hash == '' -def test_get_git_hash_minimalgit(tmpdir): - curdir = os.getcwd() +def test_get_git_hash_minimalgit(tmp_path: Path): + curdir = Path.cwd() try: - minimalgit_dir = tmpdir.mkdir("minimalgit_dir") - minimalgit_dir.join('testfile.txt').write('minimalgit') + minimalgit_dir: Path = tmp_path / "minimalgit_dir" + minimalgit_dir.mkdir() + (minimalgit_dir / 'testfile.txt').write_text('minimalgit') os.chdir(str(minimalgit_dir)) subprocess.check_call(['git', 'init']) subprocess.check_call(['git', 'config', 'user.email', 'test@domain.invalid']) diff --git a/zkpytb/dicts.py b/zkpytb/dicts.py index eac442f..b8a625c 100644 --- a/zkpytb/dicts.py +++ b/zkpytb/dicts.py @@ -9,6 +9,7 @@ import json from collections import OrderedDict +from typing import Any, Callable, Dict, Iterator, List from zkpytb.json import JSONEncoder @@ -43,18 +44,18 @@ class AutoOrderedDict(OrderedDict, AutoDict): _base_class = OrderedDict -def filter_dict_callfunc(dict_in, func): +def filter_dict_callfunc(dict_in: Dict, func: Callable[..., bool]) -> Dict: assert isinstance(dict_in, dict) assert callable(func) return {k: v for (k, v) in dict_in.items() if func(k, v)} -def filter_dict_only_scalar_values(dict_in): +def filter_dict_only_scalar_values(dict_in: Dict) -> Dict: assert isinstance(dict_in, dict) return {k: v for (k, v) in dict_in.items() if not hasattr(v, '__iter__')} -def filter_dict_with_keylist(dict_in, keylist, blacklistmode=False): +def filter_dict_with_keylist(dict_in: Dict, keylist: 
List, blacklistmode=False) -> Dict: assert isinstance(dict_in, dict) assert isinstance(keylist, list) if blacklistmode: @@ -63,7 +64,7 @@ def filter_dict_with_keylist(dict_in, keylist, blacklistmode=False): return {k: v for (k, v) in dict_in.items() if k in keylist} -def mergedicts(dict1, dict2): +def mergedicts(dict1, dict2) -> Iterator[Any]: assert isinstance(dict1, dict) assert isinstance(dict2, dict) for k in set(dict1.keys()).union(dict2.keys()): @@ -81,11 +82,11 @@ def mergedicts(dict1, dict2): yield (k, dict2[k]) -def dict_stable_json_repr(dict_in): +def dict_stable_json_repr(dict_in: Dict) -> str: return json.dumps(dict_in, sort_keys=True, cls=JSONEncoder) -def hashdict(dict_in, method='sha1'): +def hashdict(dict_in: Dict, method='sha1') -> str: assert isinstance(dict_in, dict) h = hashlib.new(method) dict_repr = dict_stable_json_repr(dict_in) @@ -93,7 +94,7 @@ def hashdict(dict_in, method='sha1'): return h.hexdigest() -def dict_values_map(f, d): +def dict_values_map(f: Callable, d: Dict) -> Dict: """ Simple helper to apply a function to the values of a dictionary. diff --git a/zkpytb/json.py b/zkpytb/json.py index b29e142..cca095f 100644 --- a/zkpytb/json.py +++ b/zkpytb/json.py @@ -16,6 +16,7 @@ class JSONEncoder(json.JSONEncoder): """ A custom JSONEncoder that can handle a bit more data types than the one from stdlib. 
""" + def default(self, o): # early passthrough if it works by default try: diff --git a/zkpytb/logging.py b/zkpytb/logging.py index 16b9258..241f04d 100644 --- a/zkpytb/logging.py +++ b/zkpytb/logging.py @@ -1,14 +1,18 @@ import logging import logging.config from pathlib import Path +from typing import Mapping, Optional, Union mylogger = logging.getLogger('zkpytb.logging') -def setup_simple_console_and_file_logger(logger_name, logfile=True, - logdir=None, logfilename=None, - log_level='DEBUG', options=None): +def setup_simple_console_and_file_logger(logger_name: str, + logfile=True, + logdir: Optional[Union[str, Path]] = None, + logfilename: Optional[Union[str, Path]] = None, + log_level: str = 'DEBUG', + options: Optional[Mapping[str, str]] = None) -> logging.Logger: """ TODOC """ @@ -26,7 +30,7 @@ def setup_simple_console_and_file_logger(logger_name, logfile=True, log_directory = Path(logdir) except TypeError: mylogger.exception('Invalid type for argument "logdir". ' - 'Expected "str", "bytes" or "pathlib.Path". ' + 'Expected "str" or "pathlib.Path". ' 'Received type: {}'.format(type(logdir))) do_file_logging = False else: diff --git a/zkpytb/pandas.py b/zkpytb/pandas.py index d9bcd0e..d724adf 100644 --- a/zkpytb/pandas.py +++ b/zkpytb/pandas.py @@ -4,6 +4,8 @@ Author: Marc Gallet (2017) """ +from typing import List, Optional, Tuple + try: import numpy as np import pandas as pd @@ -17,10 +19,12 @@ # More percentiles when using pd.describe() -extended_percentiles = [.01, .05, .25, .5, .75, .95, .99] +extended_percentiles: List[float] = [.01, .05, .25, .5, .75, .95, .99] -def tdescr(df_in, percentiles=None, disp=True): +def tdescr(df_in: pd.DataFrame, + percentiles: Optional[List[float]] = None, + disp: bool = True) -> pd.DataFrame: """ Helper function to display and return the transposition of the output of DataFrame.describe(). 
This means that @@ -48,7 +52,7 @@ def tdescr(df_in, percentiles=None, disp=True): return tdescr_out -def df_query_with_ratio(df_in, query, ratio_name='ratio'): +def df_query_with_ratio(df_in: pd.DataFrame, query: str, ratio_name='ratio') -> Tuple[pd.DataFrame, float]: """ This function calls the .query() method on a DataFrame and additionally computes the ratio of resulting rows @@ -63,7 +67,7 @@ def df_query_with_ratio(df_in, query, ratio_name='ratio'): return df_out, ratio -def remove_outliers(df_in, column, sigma=3): +def remove_outliers(df_in: pd.DataFrame, column, sigma: float = 3) -> pd.DataFrame: """ Very simple filter that removes outlier rows from a DataFrame based on the distance from the @@ -72,7 +76,7 @@ def remove_outliers(df_in, column, sigma=3): return df_in[np.abs(df_in[column] - df_in[column].mean()) <= (sigma * df_in[column].std())] -def only_outliers(df_in, column, sigma=3): +def only_outliers(df_in: pd.DataFrame, column, sigma: float = 3) -> pd.DataFrame: """ Very simple filter that only keeps outlier rows from a DataFrame based on the distance from the @@ -81,7 +85,7 @@ def only_outliers(df_in, column, sigma=3): return df_in[np.abs(df_in[column] - df_in[column].mean()) > (sigma * df_in[column].std())] -def move_col_to_beginning_of_df(df_in, colname): +def move_col_to_beginning_of_df(df_in: pd.DataFrame, colname: str) -> pd.DataFrame: """ Small helper to move a column to the beginning of the DataFrame """ @@ -90,7 +94,7 @@ def move_col_to_beginning_of_df(df_in, colname): return df_in.reindex(columns=cols) -def compare_df_cols(df_list, col_list, mode=1): +def compare_df_cols(df_list: List[pd.DataFrame], col_list: List[str], mode=1) -> Optional[pd.DataFrame]: """ Helper to compare the values of common columns between different dataframes @@ -98,11 +102,11 @@ def compare_df_cols(df_list, col_list, mode=1): Mode 2: iterate over DataFrames as top level and columns as second level """ if mode == 1: - colstoconcat = [df.loc[:, col].rename(df.loc[:, 
col].name + '_' + str(i + 1)) + colstoconcat = [df.loc[:, col].rename(str(df.loc[:, col].name) + '_' + str(i + 1)) for col in col_list for i, df in enumerate(df_list)] elif mode == 2: - colstoconcat = [df.loc[:, col].rename(df.loc[:, col].name + '_' + str(i + 1)) + colstoconcat = [df.loc[:, col].rename(str(df.loc[:, col].name) + '_' + str(i + 1)) for i, df in enumerate(df_list) for col in col_list] else: @@ -140,7 +144,7 @@ def _percentile(x): return _percentile -def describe_numeric_1d(series): +def describe_numeric_1d(series: pd.Series): """ Patched version of pandas' .describe() function for Series which includes the calculation of the median absolute deviation and interquartile range @@ -149,7 +153,7 @@ def describe_numeric_1d(series): and all other stats are set to np.nan """ stat_index = (['count', 'mean', 'std', 'mad', 'mad_c1', 'iqr', 'min'] - + pd.io.formats.format.format_percentiles(extended_percentiles) + ['max']) + + pd.io.formats.format.format_percentiles(extended_percentiles) + ['max']) # type: ignore if series.empty: # [0, np.nan, np.nan, ..., np.nan] d = [0] + [np.nan] * (len(stat_index) - 1) diff --git a/zkpytb/priorityqueue.py b/zkpytb/priorityqueue.py index 2fe71c8..97bd3c3 100644 --- a/zkpytb/priorityqueue.py +++ b/zkpytb/priorityqueue.py @@ -16,12 +16,13 @@ class EmptyQueueError(Exception): class PriorityQueue: """Based on https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes""" + def __init__(self, name=''): self.name = name self.pq = [] # list of entries arranged in a heap self.entry_finder = {} # mapping of tasks to entries self.counter = itertools.count() # unique sequence count - self.num_tasks = 0 # track the number of tasks in the queue + self.num_tasks = 0 # track the number of tasks in the queue def __len__(self): return self.num_tasks diff --git a/zkpytb/utils.py b/zkpytb/utils.py index dc00680..5492daa 100644 --- a/zkpytb/utils.py +++ b/zkpytb/utils.py @@ -9,11 +9,13 @@ import logging import subprocess 
+from pathlib import Path +from typing import Union mylogger = logging.getLogger('zkpytb.utils') -def hashfile(filepath, hash_method='sha256', BLOCKSIZE=65536): +def hashfile(filepath: Union[str, Path], hash_method: str = 'sha256', BLOCKSIZE: int = 65536) -> str: """Hash a file""" hasher = hashlib.new(hash_method) @@ -27,15 +29,15 @@ def hashfile(filepath, hash_method='sha256', BLOCKSIZE=65536): return hasher.hexdigest() -def hashstring(inputstring, hash_method='sha256'): - """Hash a file""" +def hashstring(inputstring: bytes, hash_method: str = 'sha256') -> str: + """Hash a string""" hasher = hashlib.new(hash_method) hasher.update(inputstring) return hasher.hexdigest() -def get_git_hash(rev='HEAD'): +def get_git_hash(rev: str = 'HEAD') -> str: """Get the git hash of the current directory""" git_hash = ''