Commit 7b252d6: Added typing hints and solved some warnings
zertrin committed Nov 27, 2023 (1 parent: d2158b3)
Showing 9 changed files with 76 additions and 51 deletions.
4 changes: 2 additions & 2 deletions tests/test_zkpytb_logging.py
@@ -121,9 +121,9 @@ def test_setup_simple_console_and_file_logger_4(tmpdir):
"""
logger_name = 'logtest4'
with LogCapture('zkpytb.logging') as lc2:
logger = setup_simple_console_and_file_logger(logger_name, logdir=123)
logger = setup_simple_console_and_file_logger(logger_name, logdir=123) # type: ignore
err_msg = ('Invalid type for argument "logdir". '
'Expected "str", "bytes" or "pathlib.Path". '
'Expected "str" or "pathlib.Path". '
'Received type: {}'.format(type(123)))
lc2.check(('zkpytb.logging', 'ERROR', err_msg))
assert len(logger.handlers) == 1
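
(Aside, not part of the commit: a minimal sketch of why the deliberately wrong argument now carries "# type: ignore". Once the function is annotated, a type checker such as mypy flags logdir=123 as an error, while the test still wants to exercise the runtime check. The helper below is hypothetical and only mirrors the idea.)

    from pathlib import Path
    from typing import Optional, Union

    def takes_logdir(logdir: Optional[Union[str, Path]] = None) -> None:
        # Hypothetical stand-in for the runtime validation in zkpytb.logging.
        if logdir is not None and not isinstance(logdir, (str, Path)):
            raise TypeError('Invalid type for argument "logdir": {}'.format(type(logdir)))

    try:
        takes_logdir(123)  # type: ignore  # intentionally wrong type; silences mypy, keeps the runtime path testable
    except TypeError as exc:
        print(exc)
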
4 changes: 3 additions & 1 deletion tests/test_zkpytb_pandas.py
@@ -49,7 +49,7 @@ def generate_test_df(df_len, rand_seed):
'd': np.arange(df_len, dtype=float),
'e': np.random.RandomState(rand_seed).rand(df_len),
'f': np.random.RandomState(rand_seed).randn(df_len),
'g': np.random.RandomState(rand_seed).random_integers(0, 100, df_len),
'g': np.random.RandomState(rand_seed).randint(0, 100 + 1, df_len),
'n': np.zeros(df_len) + np.nan,
}, index=range(df_len))

@@ -116,11 +116,13 @@ def test_move_col_to_beginning_of_df(df1):

def test_compare_df_cols_mode1(df1, df2):
res = compare_df_cols([df1, df2], ['e', 'f', 'g'], mode=1)
assert res is not None
assert list(res.columns) == ['e_1', 'e_2', 'f_1', 'f_2', 'g_1', 'g_2']


def test_compare_df_cols_mode2(df1, df2):
res = compare_df_cols([df1, df2], ['e', 'f', 'g'], mode=2)
assert res is not None
assert list(res.columns) == ['e_1', 'f_1', 'g_1', 'e_2', 'f_2', 'g_2']


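
(Aside, not from the diff: RandomState.random_integers() is deprecated and emits a DeprecationWarning because its upper bound is inclusive, unlike randint(), whose upper bound is exclusive; hence the "100 + 1" in the updated test. A quick check of the equivalence:)

    import numpy as np

    seed = 42
    # random_integers(0, 100) drew from the closed interval [0, 100];
    # randint's high bound is exclusive, so 100 + 1 reproduces the same range.
    values = np.random.RandomState(seed).randint(0, 100 + 1, 10)
    assert values.min() >= 0 and values.max() <= 100
    print(values)
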
52 changes: 31 additions & 21 deletions tests/test_zkpytb_utils.py
@@ -6,6 +6,7 @@
import os
import pytest
import subprocess
from pathlib import Path


from zkpytb.utils import (
@@ -24,41 +25,49 @@ def a_hash_method(request):
def test_hashstring_hashmethods(a_hash_method):
hash_res1 = hashstring(b'', hash_method=a_hash_method)
hash_res2 = hashstring(b'test', hash_method=a_hash_method)
assert isinstance(hash_res1, str)
assert len(hash_res1) > 0
hash_res3 = hashstring(b'test', hash_method=a_hash_method)
assert all(isinstance(res, str) for res in [hash_res1, hash_res2, hash_res3])
assert all(len(res) > 0 for res in [hash_res1, hash_res2, hash_res3])
assert hash_res1 != hash_res2
assert hash_res2 == hash_res3


def test_hashfile_hashmethods(a_hash_method, tmpdir):
f1 = tmpdir.join("hashfile1.txt")
f2 = tmpdir.join("hashfile2.txt")
f1.write("content1")
f2.write("content2")
def test_hashfile_hashmethods(a_hash_method, tmp_path: Path):
f1 = tmp_path / "hashfile1.txt"
f2 = tmp_path / "hashfile2.txt"
f3 = tmp_path / "hashfile3.txt"
f1.write_text("content1")
f2.write_text("content2")
f3.write_text("content2")
hashfile_res1 = hashfile(f1, hash_method=a_hash_method)
hashfile_res2 = hashfile(f2, hash_method=a_hash_method)
assert isinstance(hashfile_res1, str)
assert len(hashfile_res1) > 0
hashfile_res3 = hashfile(f3, hash_method=a_hash_method)
assert all(isinstance(res, str) for res in [hashfile_res1, hashfile_res2, hashfile_res3])
assert all(len(res) > 0 for res in [hashfile_res1, hashfile_res2, hashfile_res3])
assert hashfile_res1 != hashfile_res2
assert hashfile_res2 == hashfile_res3


@pytest.mark.xfail(reason="Tox temp dir is located under the main git repository of the project...")
def test_get_git_hash_nogit(tmpdir):
curdir = os.getcwd()
def test_get_git_hash_nogit(tmp_path: Path):
curdir = Path.cwd()
try:
nogit_dir = tmpdir.mkdir("nogit_dir")
os.chdir(str(nogit_dir))
nogit_dir = (tmp_path / "nogit_dir").mkdir()
assert nogit_dir is not None
os.chdir(nogit_dir)
git_hash = get_git_hash()
finally:
os.chdir(curdir)

assert git_hash == ''


def test_get_git_hash_emptygit(tmpdir):
curdir = os.getcwd()
def test_get_git_hash_emptygit(tmp_path: Path):
curdir = Path.cwd()
try:
emptygit_dir = tmpdir.mkdir("emptygit_dir")
os.chdir(str(emptygit_dir))
emptygit_dir: Path = tmp_path / "emptygit_dir"
emptygit_dir.mkdir()
os.chdir(emptygit_dir)
subprocess.check_call(['git', 'init'])
git_hash = get_git_hash()
finally:
@@ -67,11 +76,12 @@ def test_get_git_hash_emptygit(tmpdir):
assert git_hash == ''


def test_get_git_hash_minimalgit(tmpdir):
curdir = os.getcwd()
def test_get_git_hash_minimalgit(tmp_path: Path):
curdir = Path.cwd()
try:
minimalgit_dir = tmpdir.mkdir("minimalgit_dir")
minimalgit_dir.join('testfile.txt').write('minimalgit')
minimalgit_dir: Path = tmp_path / "minimalgit_dir"
minimalgit_dir.mkdir()
(minimalgit_dir / 'testfile.txt').write_text('minimalgit')
os.chdir(str(minimalgit_dir))
subprocess.check_call(['git', 'init'])
subprocess.check_call(['git', 'config', 'user.email', '[email protected]'])
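
(Aside, not code from the repository: pytest's tmp_path fixture provides a pathlib.Path, which is why the rewritten tests can use the / operator, mkdir() and write_text() instead of py.path's join()/write(). A minimal illustration:)

    from pathlib import Path

    def test_tmp_path_is_a_pathlib_path(tmp_path: Path):
        # tmp_path is a unique temporary directory created by pytest for each test.
        subdir = tmp_path / "subdir"
        subdir.mkdir()  # note: Path.mkdir() returns None, so keep the Path object itself
        f = subdir / "example.txt"
        f.write_text("content")
        assert f.read_text() == "content"
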
15 changes: 8 additions & 7 deletions zkpytb/dicts.py
@@ -9,6 +9,7 @@
import json

from collections import OrderedDict
from typing import Any, Callable, Dict, Iterator, List

from zkpytb.json import JSONEncoder

@@ -43,18 +44,18 @@ class AutoOrderedDict(OrderedDict, AutoDict):
_base_class = OrderedDict


def filter_dict_callfunc(dict_in, func):
def filter_dict_callfunc(dict_in: Dict, func: Callable[..., bool]) -> Dict:
assert isinstance(dict_in, dict)
assert callable(func)
return {k: v for (k, v) in dict_in.items() if func(k, v)}


def filter_dict_only_scalar_values(dict_in):
def filter_dict_only_scalar_values(dict_in: Dict) -> Dict:
assert isinstance(dict_in, dict)
return {k: v for (k, v) in dict_in.items() if not hasattr(v, '__iter__')}


def filter_dict_with_keylist(dict_in, keylist, blacklistmode=False):
def filter_dict_with_keylist(dict_in: Dict, keylist: List, blacklistmode=False) -> Dict:
assert isinstance(dict_in, dict)
assert isinstance(keylist, list)
if blacklistmode:
@@ -63,7 +64,7 @@ def filter_dict_with_keylist(dict_in, keylist, blacklistmode=False):
return {k: v for (k, v) in dict_in.items() if k in keylist}


def mergedicts(dict1, dict2):
def mergedicts(dict1, dict2) -> Iterator[Any]:
assert isinstance(dict1, dict)
assert isinstance(dict2, dict)
for k in set(dict1.keys()).union(dict2.keys()):
@@ -81,19 +82,19 @@ def mergedicts(dict1, dict2):
yield (k, dict2[k])


def dict_stable_json_repr(dict_in):
def dict_stable_json_repr(dict_in: Dict) -> str:
return json.dumps(dict_in, sort_keys=True, cls=JSONEncoder)


def hashdict(dict_in, method='sha1'):
def hashdict(dict_in: Dict, method='sha1') -> str:
assert isinstance(dict_in, dict)
h = hashlib.new(method)
dict_repr = dict_stable_json_repr(dict_in)
h.update(dict_repr.encode('utf-8'))
return h.hexdigest()


def dict_values_map(f, d):
def dict_values_map(f: Callable, d: Dict) -> Dict:
"""
Simple helper to apply a function to the values of a dictionary.
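
(Aside, a usage sketch that is not part of the repository, illustrating the arguments the new Dict/Callable/List hints describe; it assumes the helpers behave as shown in the hunks above:)

    from zkpytb.dicts import (
        filter_dict_callfunc,
        filter_dict_only_scalar_values,
        filter_dict_with_keylist,
        mergedicts,
    )

    d = {'a': 1, 'b': [1, 2], 'c': 3}

    # Keep entries whose (key, value) pair satisfies the predicate.
    print(filter_dict_callfunc(d, lambda k, v: isinstance(v, int) and v > 1))  # {'c': 3}

    # Keep only non-iterable (scalar) values.
    print(filter_dict_only_scalar_values(d))  # {'a': 1, 'c': 3}

    # Whitelist by key (blacklistmode=True inverts the filter).
    print(filter_dict_with_keylist(d, ['a', 'c']))  # {'a': 1, 'c': 3}

    # mergedicts() yields (key, value) pairs, hence the Iterator return hint.
    print(dict(mergedicts({'x': 1}, {'y': 2})))  # {'x': 1, 'y': 2} (key order may vary)
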
1 change: 1 addition & 0 deletions zkpytb/json.py
@@ -16,6 +16,7 @@ class JSONEncoder(json.JSONEncoder):
"""
A custom JSONEncoder that can handle a bit more data types than the one from stdlib.
"""

def default(self, o):
# early passthrough if it works by default
try:
Expand Down
12 changes: 8 additions & 4 deletions zkpytb/logging.py
@@ -1,14 +1,18 @@
import logging
import logging.config
from pathlib import Path
from typing import Mapping, Optional, Union


mylogger = logging.getLogger('zkpytb.logging')


def setup_simple_console_and_file_logger(logger_name, logfile=True,
logdir=None, logfilename=None,
log_level='DEBUG', options=None):
def setup_simple_console_and_file_logger(logger_name: str,
logfile=True,
logdir: Optional[Union[str, Path]] = None,
logfilename: Optional[Union[str, Path]] = None,
log_level: str = 'DEBUG',
options: Optional[Mapping[str, str]] = None) -> logging.Logger:
"""
TODOC
"""
@@ -26,7 +30,7 @@ def setup_simple_console_and_file_logger(logger_name, logfile=True,
log_directory = Path(logdir)
except TypeError:
mylogger.exception('Invalid type for argument "logdir". '
'Expected "str", "bytes" or "pathlib.Path". '
'Expected "str" or "pathlib.Path". '
'Received type: {}'.format(type(logdir)))
do_file_logging = False
else:
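
(Aside, not from the commit: a hedged usage sketch of the annotated signature. The Optional[Union[str, Path]] hint documents that logdir may be omitted or passed as either a str or a pathlib.Path; the call below assumes, based on the parameter names only, that logfile=False sets up a console-only logger.)

    from zkpytb.logging import setup_simple_console_and_file_logger

    # Console-only logger; when logfile=True, logdir accepts a str or a pathlib.Path.
    logger = setup_simple_console_and_file_logger('myapp', logfile=False, log_level='INFO')
    logger.info('typed logger ready')
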
26 changes: 15 additions & 11 deletions zkpytb/pandas.py
@@ -4,6 +4,8 @@
Author: Marc Gallet (2017)
"""

from typing import List, Optional, Tuple

try:
import numpy as np
import pandas as pd
@@ -17,10 +19,12 @@


# More percentiles when using pd.describe()
extended_percentiles = [.01, .05, .25, .5, .75, .95, .99]
extended_percentiles: List[float] = [.01, .05, .25, .5, .75, .95, .99]


def tdescr(df_in, percentiles=None, disp=True):
def tdescr(df_in: pd.DataFrame,
percentiles: Optional[List[float]] = None,
disp: bool = True) -> pd.DataFrame:
"""
Helper function to display and return the transposition
of the output of DataFrame.describe(). This means that
@@ -48,7 +52,7 @@ def tdescr(df_in, percentiles=None, disp=True):
return tdescr_out


def df_query_with_ratio(df_in, query, ratio_name='ratio'):
def df_query_with_ratio(df_in: pd.DataFrame, query: str, ratio_name='ratio') -> Tuple[pd.DataFrame, float]:
"""
This function calls the .query() method on a DataFrame
and additionally computes the ratio of resulting rows
@@ -63,7 +67,7 @@ def df_query_with_ratio(df_in, query, ratio_name='ratio'):
return df_out, ratio


def remove_outliers(df_in, column, sigma=3):
def remove_outliers(df_in: pd.DataFrame, column, sigma: float = 3) -> pd.DataFrame:
"""
Very simple filter that removes outlier rows
from a DataFrame based on the distance from the
@@ -72,7 +76,7 @@ def remove_outliers(df_in, column, sigma=3):
return df_in[np.abs(df_in[column] - df_in[column].mean()) <= (sigma * df_in[column].std())]


def only_outliers(df_in, column, sigma=3):
def only_outliers(df_in: pd.DataFrame, column, sigma: float = 3) -> pd.DataFrame:
"""
Very simple filter that only keeps outlier rows
from a DataFrame based on the distance from the
@@ -81,7 +85,7 @@ def only_outliers(df_in, column, sigma=3):
return df_in[np.abs(df_in[column] - df_in[column].mean()) > (sigma * df_in[column].std())]


def move_col_to_beginning_of_df(df_in, colname):
def move_col_to_beginning_of_df(df_in: pd.DataFrame, colname: str) -> pd.DataFrame:
"""
Small helper to move a column to the beginning of the DataFrame
"""
@@ -90,19 +94,19 @@ def move_col_to_beginning_of_df(df_in, colname):
return df_in.reindex(columns=cols)


def compare_df_cols(df_list, col_list, mode=1):
def compare_df_cols(df_list: List[pd.DataFrame], col_list: List[str], mode=1) -> Optional[pd.DataFrame]:
"""
Helper to compare the values of common columns between different dataframes
Mode 1: iterate over columns as top level and DataFrames as second level
Mode 2: iterate over DataFrames as top level and columns as second level
"""
if mode == 1:
colstoconcat = [df.loc[:, col].rename(df.loc[:, col].name + '_' + str(i + 1))
colstoconcat = [df.loc[:, col].rename(str(df.loc[:, col].name) + '_' + str(i + 1))
for col in col_list
for i, df in enumerate(df_list)]
elif mode == 2:
colstoconcat = [df.loc[:, col].rename(df.loc[:, col].name + '_' + str(i + 1))
colstoconcat = [df.loc[:, col].rename(str(df.loc[:, col].name) + '_' + str(i + 1))
for i, df in enumerate(df_list)
for col in col_list]
else:
@@ -140,7 +144,7 @@ def _percentile(x):
return _percentile


def describe_numeric_1d(series):
def describe_numeric_1d(series: pd.Series):
"""
Patched version of pandas' .describe() function for Series
which includes the calculation of the median absolute deviation and interquartile range
Expand All @@ -149,7 +153,7 @@ def describe_numeric_1d(series):
and all other stats are set to np.nan
"""
stat_index = (['count', 'mean', 'std', 'mad', 'mad_c1', 'iqr', 'min']
+ pd.io.formats.format.format_percentiles(extended_percentiles) + ['max'])
+ pd.io.formats.format.format_percentiles(extended_percentiles) + ['max']) # type: ignore
if series.empty:
# [0, np.nan, np.nan, ..., np.nan]
d = [0] + [np.nan] * (len(stat_index) - 1)
Expand Down
3 changes: 2 additions & 1 deletion zkpytb/priorityqueue.py
@@ -16,12 +16,13 @@ class EmptyQueueError(Exception):

class PriorityQueue:
"""Based on https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes"""

def __init__(self, name=''):
self.name = name
self.pq = [] # list of entries arranged in a heap
self.entry_finder = {} # mapping of tasks to entries
self.counter = itertools.count() # unique sequence count
self.num_tasks = 0 # track the number of tasks in the queue
self.num_tasks = 0 # track the number of tasks in the queue

def __len__(self):
return self.num_tasks
10 changes: 6 additions & 4 deletions zkpytb/utils.py
@@ -9,11 +9,13 @@
import logging
import subprocess

from pathlib import Path
from typing import Union

mylogger = logging.getLogger('zkpytb.utils')


def hashfile(filepath, hash_method='sha256', BLOCKSIZE=65536):
def hashfile(filepath: Union[str, Path], hash_method: str = 'sha256', BLOCKSIZE: int = 65536) -> str:
"""Hash a file"""

hasher = hashlib.new(hash_method)
@@ -27,15 +29,15 @@ def hashfile(filepath, hash_method='sha256', BLOCKSIZE=65536):
return hasher.hexdigest()


def hashstring(inputstring, hash_method='sha256'):
"""Hash a file"""
def hashstring(inputstring: bytes, hash_method: str = 'sha256') -> str:
"""Hash a string"""

hasher = hashlib.new(hash_method)
hasher.update(inputstring)
return hasher.hexdigest()


def get_git_hash(rev='HEAD'):
def get_git_hash(rev: str = 'HEAD') -> str:
"""Get the git hash of the current directory"""

git_hash = ''
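
(Aside, a usage sketch that is not part of the repository, reflecting the new annotations: hashstring() takes bytes, both hashing helpers return a hex digest as str, and get_git_hash() returns an empty string outside a git repository, as the tests above expect.)

    from pathlib import Path
    from zkpytb.utils import get_git_hash, hashfile, hashstring

    print(hashstring(b'some payload'))            # sha256 hex digest by default

    tmp = Path('example.txt')
    tmp.write_text('content')
    print(hashfile(tmp, hash_method='sha1'))      # filepath may be a str or a Path
    tmp.unlink()

    print(repr(get_git_hash()))                   # '' when not inside a git repository
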
