Skip to content

Commit

Permalink
Merge pull request #32 from Baukebrenninkmeijer/feature/upgrade-dythi…
Browse files Browse the repository at this point in the history
…ng-version
  • Loading branch information
Bauke Brenninkmeijer authored Dec 31, 2022
2 parents 0d0db9a + 7bdcf05 commit cc5a760
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10']
python-version: ['3.8', '3.9', '3.10']
name: Python ${{ matrix.python-version }} sample
steps:
- name: Checkout
Expand Down
10 changes: 5 additions & 5 deletions data/tests/fake_associations.csv
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date
trans_id,1.0,0.4746435587895609,-0.21334695910231155,0.014054082101829615,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.029289257805795953
account_id,0.4746435587895609,1.0,0.024214269962367203,0.06713217170734949,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424448
trans_amount,-0.21334695910231155,0.024214269962367203,1.0,0.4104949508716057,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.051358949898267885
balance_after_trans,0.014054082101829615,0.06713217170734949,0.4104949508716057,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768
trans_id,1.0,0.474643558789561,-0.2133469591023116,0.0140540821018296,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.02928925780579595
account_id,0.474643558789561,1.0,0.024214269962367224,0.0671321717073495,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424455
trans_amount,-0.2133469591023116,0.024214269962367224,1.0,0.41049495087160576,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.05135894989826789
balance_after_trans,0.0140540821018296,0.0671321717073495,0.41049495087160576,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768
trans_type,0.4929328532971835,0.05505763017128429,0.14384781147688352,0.10800536670088424,1.0,0.6997816191537789,0.5221930686175845,0.056104192679759315
trans_operation,0.8514759788482767,0.10541730588406732,0.47075620299312654,0.20073482566116083,0.6997816191537789,1.0,0.6677368630738517,0.09200128175029092
trans_k_symbol,0.8714152457284374,0.20035158555476978,0.44913033115620205,0.13085896079274767,0.5221930686175845,0.6677368630738517,1.0,0.14896840029863861
trans_date,-0.029289257805795953,-0.026417221145424448,-0.051358949898267885,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0
trans_date,-0.02928925780579595,-0.026417221145424455,-0.05135894989826789,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0
16 changes: 8 additions & 8 deletions data/tests/fake_associations_theil.csv
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date
trans_id,1.0,0.4746435587895609,-0.21334695910231155,0.014054082101829615,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.029289257805795953
account_id,0.4746435587895609,1.0,0.024214269962367203,0.06713217170734949,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424448
trans_amount,-0.21334695910231155,0.024214269962367203,1.0,0.4104949508716057,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.051358949898267885
balance_after_trans,0.014054082101829615,0.06713217170734949,0.4104949508716057,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768
trans_type,0.4929328532971835,0.05505763017128429,0.14384781147688352,0.10800536670088424,1.0,0.4315720065489207,0.2502006364307283,0.056104192679759315
trans_operation,0.8514759788482767,0.10541730588406732,0.47075620299312654,0.20073482566116083,0.8307948936127061,1.0,0.6143594912151971,0.09200128175029092
trans_k_symbol,0.8714152457284374,0.20035158555476978,0.44913033115620205,0.13085896079274767,0.4751127792537729,0.6060246958534777,1.0,0.14896840029863861
trans_date,-0.029289257805795953,-0.026417221145424448,-0.051358949898267885,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0
trans_id,1.0,0.474643558789561,-0.2133469591023116,0.0140540821018296,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.02928925780579595
account_id,0.474643558789561,1.0,0.024214269962367224,0.0671321717073495,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424455
trans_amount,-0.2133469591023116,0.024214269962367224,1.0,0.41049495087160576,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.05135894989826789
balance_after_trans,0.0140540821018296,0.0671321717073495,0.41049495087160576,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768
trans_type,0.4929328532971835,0.05505763017128429,0.14384781147688352,0.10800536670088424,1.0,0.8307948936127061,0.4751127792537729,0.056104192679759315
trans_operation,0.8514759788482767,0.10541730588406732,0.47075620299312654,0.20073482566116083,0.4315720065489207,1.0,0.6060246958534777,0.09200128175029092
trans_k_symbol,0.8714152457284374,0.20035158555476978,0.44913033115620205,0.13085896079274767,0.2502006364307283,0.6143594912151971,1.0,0.14896840029863861
trans_date,-0.02928925780579595,-0.026417221145424455,-0.05135894989826789,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0
10 changes: 5 additions & 5 deletions data/tests/real_associations.csv
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date
trans_id,1.0,0.5228300496676785,-0.22033993565166976,0.028875713717958756,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.05457374320401556
account_id,0.5228300496676785,1.0,0.025829382873738007,0.10896692701953434,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068295
trans_amount,-0.22033993565166976,0.025829382873738007,1.0,0.39474608564158525,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.016639052876948952
balance_after_trans,0.028875713717958756,0.10896692701953434,0.39474608564158525,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006092
trans_id,1.0,0.5228300496676787,-0.22033993565166987,0.028875713717958783,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.054573743204015576
account_id,0.5228300496676787,1.0,0.025829382873738,0.10896692701953438,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068296
trans_amount,-0.22033993565166987,0.025829382873738,1.0,0.39474608564158536,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.01663905287694896
balance_after_trans,0.028875713717958783,0.10896692701953438,0.39474608564158536,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006098
trans_type,0.47428712569977377,0.04827916714843414,0.19389971128773834,0.12419170529475135,1.0,0.7092204023979536,0.5164395860852621,0.031496012385482906
trans_operation,0.8352915293991163,0.08816937558698971,0.5311580367062264,0.23006188182617215,0.7092204023979536,1.0,0.6631455360238371,0.0930897099137863
trans_k_symbol,0.8414511312320141,0.16462397499373396,0.4949401219382614,0.1781853887887315,0.5164395860852621,0.6631455360238371,1.0,0.10794487843160071
trans_date,0.05457374320401556,0.05213217137068295,0.016639052876948952,0.12393532301006092,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0
trans_date,0.054573743204015576,0.05213217137068296,0.01663905287694896,0.12393532301006098,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0
16 changes: 8 additions & 8 deletions data/tests/real_associations_theil.csv
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date
trans_id,1.0,0.5228300496676785,-0.22033993565166976,0.028875713717958756,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.05457374320401556
account_id,0.5228300496676785,1.0,0.025829382873738007,0.10896692701953434,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068295
trans_amount,-0.22033993565166976,0.025829382873738007,1.0,0.39474608564158525,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.016639052876948952
balance_after_trans,0.028875713717958756,0.10896692701953434,0.39474608564158525,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006092
trans_type,0.47428712569977377,0.04827916714843414,0.19389971128773834,0.12419170529475135,1.0,0.4591498022313355,0.25020452459692694,0.031496012385482906
trans_operation,0.8352915293991163,0.08816937558698971,0.5311580367062264,0.23006188182617215,0.9044741348299755,1.0,0.6177492339768428,0.0930897099137863
trans_k_symbol,0.8414511312320141,0.16462397499373396,0.4949401219382614,0.1781853887887315,0.4820483704707144,0.6041794163417048,1.0,0.10794487843160071
trans_date,0.05457374320401556,0.05213217137068295,0.016639052876948952,0.12393532301006092,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0
trans_id,1.0,0.5228300496676787,-0.22033993565166987,0.028875713717958783,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.054573743204015576
account_id,0.5228300496676787,1.0,0.025829382873738,0.10896692701953438,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068296
trans_amount,-0.22033993565166987,0.025829382873738,1.0,0.39474608564158536,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.01663905287694896
balance_after_trans,0.028875713717958783,0.10896692701953438,0.39474608564158536,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006098
trans_type,0.47428712569977377,0.04827916714843414,0.19389971128773834,0.12419170529475135,1.0,0.9044741348299755,0.4820483704707144,0.031496012385482906
trans_operation,0.8352915293991163,0.08816937558698971,0.5311580367062264,0.23006188182617215,0.4591498022313355,1.0,0.6041794163417048,0.0930897099137863
trans_k_symbol,0.8414511312320141,0.16462397499373396,0.4949401219382614,0.1781853887887315,0.25020452459692694,0.6177492339768428,1.0,0.10794487843160071
trans_date,0.054573743204015576,0.05213217137068296,0.01663905287694896,0.12393532301006098,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
dython==0.5.1
dython==0.7.3
scipy
seaborn<=0.11.1
pandas
pandas==1.5.*
seaborn
numpy
matplotlib
tqdm
Expand Down
9 changes: 4 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="table-evaluator",
version="v1.4.2",
version="v1.5.0",
author="Bauke Brenninkmeijer",
author_email="[email protected]",
description="A package to evaluate how close a synthetic data set is to real data.",
Expand All @@ -14,20 +14,19 @@
url="https://github.com/Baukebrenninkmeijer/Table-Evaluator",
packages=setuptools.find_packages(),
install_requires=[
'pandas',
'pandas==1.5.*',
'numpy',
'tqdm',
'psutil',
'dython==0.5.1',
'seaborn<=0.11.1',
'dython==0.7.3',
'seaborn',
'matplotlib',
'scikit-learn',
'scipy'
],
classifiers=[
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
Expand Down
34 changes: 17 additions & 17 deletions table_evaluator/table_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sklearn.exceptions import ConvergenceWarning
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import Lasso, Ridge, ElasticNet, LogisticRegression
from dython.nominal import compute_associations, numerical_encoding
from dython.nominal import associations, numerical_encoding
from .viz import *
from .metrics import *
from .notebook import visualize_notebook, isnotebook, EvaluationResult
Expand Down Expand Up @@ -92,15 +92,15 @@ def __init__(self, real: pd.DataFrame, fake: pd.DataFrame, cat_cols=None, unique
def plot_mean_std(self, fname=None):
"""
Class wrapper function for plotting the mean and std using `viz.plot_mean_std`.
:param fname: If not none, saves the plot with this file name.
:param fname: If not none, saves the plot with this file name.
"""
plot_mean_std(self.real, self.fake, fname=fname)

def plot_cumsums(self, nr_cols=4, fname=None):
"""
Plot the cumulative sums for all columns in the real and fake dataset. Height of each row scales with the length of the labels. Each plot contains the
values of a real column and the corresponding fake column.
:param fname: If not none, saves the plot with this file name.
:param fname: If not none, saves the plot with this file name.
"""
nr_charts = len(self.real.columns)
nr_rows = max(1, nr_charts // nr_cols)
Expand All @@ -124,7 +124,7 @@ def plot_cumsums(self, nr_cols=4, fname=None):
cdf(r, f, col, 'Cumsum', ax=axes[i])
plt.tight_layout(rect=[0, 0.02, 1, 0.98])

if fname is not None:
if fname is not None:
plt.savefig(fname)

plt.show()
Expand All @@ -133,7 +133,7 @@ def plot_distributions(self, nr_cols=3, fname=None):
"""
Plot the distribution plots for all columns in the real and fake dataset. Height of each row of plots scales with the length of the labels. Each plot
contains the values of a real column and the corresponding fake column.
:param fname: If not none, saves the plot with this file name.
:param fname: If not none, saves the plot with this file name.
"""
nr_charts = len(self.real.columns)
nr_rows = max(1, nr_charts // nr_cols)
Expand Down Expand Up @@ -175,7 +175,7 @@ def plot_distributions(self, nr_cols=3, fname=None):
ax.set_xticklabels(axes[i].get_xticklabels(), rotation='vertical')
plt.tight_layout(rect=[0, 0.02, 1, 0.98])

if fname is not None:
if fname is not None:
plt.savefig(fname)

plt.show()
Expand Down Expand Up @@ -213,8 +213,8 @@ def custom_cosine(a, b):
else:
raise ValueError(f'`how` parameter must be in [euclidean, mae, rmse]')

real_corr = compute_associations(self.real, nominal_columns=self.categorical_columns, theil_u=True)
fake_corr = compute_associations(self.fake, nominal_columns=self.categorical_columns, theil_u=True)
real_corr = associations(self.real, nominal_columns=self.categorical_columns, nom_nom_assoc='theil', compute_only=True)
fake_corr = associations(self.fake, nominal_columns=self.categorical_columns, nom_nom_assoc='theil', compute_only=True)

return distance_func(
real_corr.values,
Expand All @@ -241,7 +241,7 @@ def plot_pca(self, fname=None):
ax[0].set_title('Real data')
ax[1].set_title('Fake data')

if fname is not None:
if fname is not None:
plt.savefig(fname)

plt.show()
Expand Down Expand Up @@ -367,25 +367,25 @@ def score_estimators(self):
def visual_evaluation(self, save_dir=None, **kwargs):
"""
Plot all visual evaluation metrics. Includes plotting the mean and standard deviation, cumulative sums, correlation differences and the PCA transform.
:param save_dir: directory path to save images
:param save_dir: directory path to save images
:param kwargs: any kwargs for matplotlib.
"""
if save_dir is None:
if save_dir is None:
self.plot_mean_std()
self.plot_cumsums()
self.plot_distributions()
self.plot_correlation_difference(**kwargs)
self.plot_pca()
else:
self.plot_pca()
else:
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)

self.plot_mean_std(fname=save_dir/'mean_std.png')
self.plot_cumsums(fname=save_dir/'cumsums.png')
self.plot_distributions(fname=save_dir/'distributions.png')
self.plot_correlation_difference(fname=save_dir/'correlation_difference.png', **kwargs)
self.plot_pca(fname=save_dir/'pca.png')
self.plot_pca(fname=save_dir/'pca.png')


def basic_statistical_evaluation(self) -> float:
"""
Expand Down Expand Up @@ -428,7 +428,7 @@ def correlation_correlation(self) -> float:
total_metrics = pd.DataFrame()
for ds_name in ['real', 'fake']:
ds = getattr(self, ds_name)
corr_df = compute_associations(ds, nominal_columns=self.categorical_columns, theil_u=True)
corr_df = associations(ds, nominal_columns=self.categorical_columns, nom_nom_assoc='theil', compute_only=True)
values = corr_df.values
values = values[~np.eye(values.shape[0], dtype=bool)].reshape(values.shape[0], -1)
total_metrics[ds_name] = values.flatten()
Expand Down Expand Up @@ -621,7 +621,7 @@ def evaluate(self, target_col: str, target_type: str = 'class', metric: str = No
:param kfold: Use a 5-fold CV for the ML estimators if set to True. Train/Test on 80%/20% of the data if set to False.
:param notebook: Better visualization of the results in a python notebook
:param verbose: whether to print verbose logging.
:param return_outputs: Will omit printing and instead return a dictionary with all results.
:param return_outputs: Will omit printing and instead return a dictionary with all results.
"""
self.verbose = verbose if verbose is not None else self.verbose
self.comparison_metric = metric if metric is not None else self.comparison_metric
Expand Down
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import sys
sys.path.append('..')
Loading

0 comments on commit cc5a760

Please sign in to comment.