diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index c111d83..716ed66 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10'] + python-version: ['3.8', '3.9', '3.10'] name: Python ${{ matrix.python-version }} sample steps: - name: Checkout diff --git a/data/tests/fake_associations.csv b/data/tests/fake_associations.csv index b3a7646..94c855f 100644 --- a/data/tests/fake_associations.csv +++ b/data/tests/fake_associations.csv @@ -1,9 +1,9 @@ ,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date -trans_id,1.0,0.4746435587895609,-0.21334695910231155,0.014054082101829615,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.029289257805795953 -account_id,0.4746435587895609,1.0,0.024214269962367203,0.06713217170734949,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424448 -trans_amount,-0.21334695910231155,0.024214269962367203,1.0,0.4104949508716057,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.051358949898267885 -balance_after_trans,0.014054082101829615,0.06713217170734949,0.4104949508716057,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768 +trans_id,1.0,0.474643558789561,-0.2133469591023116,0.0140540821018296,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.02928925780579595 +account_id,0.474643558789561,1.0,0.024214269962367224,0.0671321717073495,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424455 +trans_amount,-0.2133469591023116,0.024214269962367224,1.0,0.41049495087160576,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.05135894989826789 +balance_after_trans,0.0140540821018296,0.0671321717073495,0.41049495087160576,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768 
trans_type,0.4929328532971835,0.05505763017128429,0.14384781147688352,0.10800536670088424,1.0,0.6997816191537789,0.5221930686175845,0.056104192679759315 trans_operation,0.8514759788482767,0.10541730588406732,0.47075620299312654,0.20073482566116083,0.6997816191537789,1.0,0.6677368630738517,0.09200128175029092 trans_k_symbol,0.8714152457284374,0.20035158555476978,0.44913033115620205,0.13085896079274767,0.5221930686175845,0.6677368630738517,1.0,0.14896840029863861 -trans_date,-0.029289257805795953,-0.026417221145424448,-0.051358949898267885,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0 +trans_date,-0.02928925780579595,-0.026417221145424455,-0.05135894989826789,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0 diff --git a/data/tests/fake_associations_theil.csv b/data/tests/fake_associations_theil.csv index ccc5f85..7b5e658 100644 --- a/data/tests/fake_associations_theil.csv +++ b/data/tests/fake_associations_theil.csv @@ -1,9 +1,9 @@ ,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date -trans_id,1.0,0.4746435587895609,-0.21334695910231155,0.014054082101829615,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.029289257805795953 -account_id,0.4746435587895609,1.0,0.024214269962367203,0.06713217170734949,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424448 -trans_amount,-0.21334695910231155,0.024214269962367203,1.0,0.4104949508716057,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.051358949898267885 -balance_after_trans,0.014054082101829615,0.06713217170734949,0.4104949508716057,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768 -trans_type,0.4929328532971835,0.05505763017128429,0.14384781147688352,0.10800536670088424,1.0,0.4315720065489207,0.2502006364307283,0.056104192679759315 
-trans_operation,0.8514759788482767,0.10541730588406732,0.47075620299312654,0.20073482566116083,0.8307948936127061,1.0,0.6143594912151971,0.09200128175029092 -trans_k_symbol,0.8714152457284374,0.20035158555476978,0.44913033115620205,0.13085896079274767,0.4751127792537729,0.6060246958534777,1.0,0.14896840029863861 -trans_date,-0.029289257805795953,-0.026417221145424448,-0.051358949898267885,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0 +trans_id,1.0,0.474643558789561,-0.2133469591023116,0.0140540821018296,0.4929328532971835,0.8514759788482767,0.8714152457284374,-0.02928925780579595 +account_id,0.474643558789561,1.0,0.024214269962367224,0.0671321717073495,0.05505763017128429,0.10541730588406732,0.20035158555476978,-0.026417221145424455 +trans_amount,-0.2133469591023116,0.024214269962367224,1.0,0.41049495087160576,0.14384781147688352,0.47075620299312654,0.44913033115620205,-0.05135894989826789 +balance_after_trans,0.0140540821018296,0.0671321717073495,0.41049495087160576,1.0,0.10800536670088424,0.20073482566116083,0.13085896079274767,0.03224007887636768 +trans_type,0.4929328532971835,0.05505763017128429,0.14384781147688352,0.10800536670088424,1.0,0.8307948936127061,0.4751127792537729,0.056104192679759315 +trans_operation,0.8514759788482767,0.10541730588406732,0.47075620299312654,0.20073482566116083,0.4315720065489207,1.0,0.6060246958534777,0.09200128175029092 +trans_k_symbol,0.8714152457284374,0.20035158555476978,0.44913033115620205,0.13085896079274767,0.2502006364307283,0.6143594912151971,1.0,0.14896840029863861 +trans_date,-0.02928925780579595,-0.026417221145424455,-0.05135894989826789,0.03224007887636768,0.056104192679759315,0.09200128175029092,0.14896840029863861,1.0 diff --git a/data/tests/real_associations.csv b/data/tests/real_associations.csv index 76a100f..ca10882 100644 --- a/data/tests/real_associations.csv +++ b/data/tests/real_associations.csv @@ -1,9 +1,9 @@ 
,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date -trans_id,1.0,0.5228300496676785,-0.22033993565166976,0.028875713717958756,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.05457374320401556 -account_id,0.5228300496676785,1.0,0.025829382873738007,0.10896692701953434,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068295 -trans_amount,-0.22033993565166976,0.025829382873738007,1.0,0.39474608564158525,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.016639052876948952 -balance_after_trans,0.028875713717958756,0.10896692701953434,0.39474608564158525,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006092 +trans_id,1.0,0.5228300496676787,-0.22033993565166987,0.028875713717958783,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.054573743204015576 +account_id,0.5228300496676787,1.0,0.025829382873738,0.10896692701953438,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068296 +trans_amount,-0.22033993565166987,0.025829382873738,1.0,0.39474608564158536,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.01663905287694896 +balance_after_trans,0.028875713717958783,0.10896692701953438,0.39474608564158536,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006098 trans_type,0.47428712569977377,0.04827916714843414,0.19389971128773834,0.12419170529475135,1.0,0.7092204023979536,0.5164395860852621,0.031496012385482906 trans_operation,0.8352915293991163,0.08816937558698971,0.5311580367062264,0.23006188182617215,0.7092204023979536,1.0,0.6631455360238371,0.0930897099137863 trans_k_symbol,0.8414511312320141,0.16462397499373396,0.4949401219382614,0.1781853887887315,0.5164395860852621,0.6631455360238371,1.0,0.10794487843160071 -trans_date,0.05457374320401556,0.05213217137068295,0.016639052876948952,0.12393532301006092,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0 
+trans_date,0.054573743204015576,0.05213217137068296,0.01663905287694896,0.12393532301006098,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0 diff --git a/data/tests/real_associations_theil.csv b/data/tests/real_associations_theil.csv index aedf241..9367ce8 100644 --- a/data/tests/real_associations_theil.csv +++ b/data/tests/real_associations_theil.csv @@ -1,9 +1,9 @@ ,trans_id,account_id,trans_amount,balance_after_trans,trans_type,trans_operation,trans_k_symbol,trans_date -trans_id,1.0,0.5228300496676785,-0.22033993565166976,0.028875713717958756,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.05457374320401556 -account_id,0.5228300496676785,1.0,0.025829382873738007,0.10896692701953434,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068295 -trans_amount,-0.22033993565166976,0.025829382873738007,1.0,0.39474608564158525,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.016639052876948952 -balance_after_trans,0.028875713717958756,0.10896692701953434,0.39474608564158525,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006092 -trans_type,0.47428712569977377,0.04827916714843414,0.19389971128773834,0.12419170529475135,1.0,0.4591498022313355,0.25020452459692694,0.031496012385482906 -trans_operation,0.8352915293991163,0.08816937558698971,0.5311580367062264,0.23006188182617215,0.9044741348299755,1.0,0.6177492339768428,0.0930897099137863 -trans_k_symbol,0.8414511312320141,0.16462397499373396,0.4949401219382614,0.1781853887887315,0.4820483704707144,0.6041794163417048,1.0,0.10794487843160071 -trans_date,0.05457374320401556,0.05213217137068295,0.016639052876948952,0.12393532301006092,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0 +trans_id,1.0,0.5228300496676787,-0.22033993565166987,0.028875713717958783,0.47428712569977377,0.8352915293991163,0.8414511312320141,0.054573743204015576 
+account_id,0.5228300496676787,1.0,0.025829382873738,0.10896692701953438,0.04827916714843414,0.08816937558698971,0.16462397499373396,0.05213217137068296 +trans_amount,-0.22033993565166987,0.025829382873738,1.0,0.39474608564158536,0.19389971128773834,0.5311580367062264,0.4949401219382614,0.01663905287694896 +balance_after_trans,0.028875713717958783,0.10896692701953438,0.39474608564158536,1.0,0.12419170529475135,0.23006188182617215,0.1781853887887315,0.12393532301006098 +trans_type,0.47428712569977377,0.04827916714843414,0.19389971128773834,0.12419170529475135,1.0,0.9044741348299755,0.4820483704707144,0.031496012385482906 +trans_operation,0.8352915293991163,0.08816937558698971,0.5311580367062264,0.23006188182617215,0.4591498022313355,1.0,0.6041794163417048,0.0930897099137863 +trans_k_symbol,0.8414511312320141,0.16462397499373396,0.4949401219382614,0.1781853887887315,0.25020452459692694,0.6177492339768428,1.0,0.10794487843160071 +trans_date,0.054573743204015576,0.05213217137068296,0.01663905287694896,0.12393532301006098,0.031496012385482906,0.0930897099137863,0.10794487843160071,1.0 diff --git a/requirements.txt b/requirements.txt index d2e5054..8f82577 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -dython==0.5.1 +dython==0.7.3 scipy -seaborn<=0.11.1 -pandas +pandas==1.5.* +seaborn numpy matplotlib tqdm diff --git a/setup.py b/setup.py index 3ebb5a2..df07d7e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="table-evaluator", - version="v1.4.2", + version="v1.5.0", author="Bauke Brenninkmeijer", author_email="bauke.brenninkmeijer@gmail.com", description="A package to evaluate how close a synthetic data set is to real data.", @@ -14,12 +14,12 @@ url="https://github.com/Baukebrenninkmeijer/Table-Evaluator", packages=setuptools.find_packages(), install_requires=[ - 'pandas', + 'pandas==1.5.*', 'numpy', 'tqdm', 'psutil', - 'dython==0.5.1', - 'seaborn<=0.11.1', + 'dython==0.7.3', + 'seaborn', 'matplotlib', 'scikit-learn', 
'scipy' @@ -27,7 +27,6 @@ classifiers=[ 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', diff --git a/table_evaluator/table_evaluator.py b/table_evaluator/table_evaluator.py index a24c617..289cd1f 100644 --- a/table_evaluator/table_evaluator.py +++ b/table_evaluator/table_evaluator.py @@ -18,7 +18,7 @@ from sklearn.exceptions import ConvergenceWarning from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier from sklearn.linear_model import Lasso, Ridge, ElasticNet, LogisticRegression -from dython.nominal import compute_associations, numerical_encoding +from dython.nominal import associations, numerical_encoding from .viz import * from .metrics import * from .notebook import visualize_notebook, isnotebook, EvaluationResult @@ -92,7 +92,7 @@ def __init__(self, real: pd.DataFrame, fake: pd.DataFrame, cat_cols=None, unique def plot_mean_std(self, fname=None): """ Class wrapper function for plotting the mean and std using `viz.plot_mean_std`. - :param fname: If not none, saves the plot with this file name. + :param fname: If not none, saves the plot with this file name. """ plot_mean_std(self.real, self.fake, fname=fname) @@ -100,7 +100,7 @@ def plot_cumsums(self, nr_cols=4, fname=None): """ Plot the cumulative sums for all columns in the real and fake dataset. Height of each row scales with the length of the labels. Each plot contains the values of a real columns and the corresponding fake column. - :param fname: If not none, saves the plot with this file name. + :param fname: If not none, saves the plot with this file name. 
""" nr_charts = len(self.real.columns) nr_rows = max(1, nr_charts // nr_cols) @@ -124,7 +124,7 @@ def plot_cumsums(self, nr_cols=4, fname=None): cdf(r, f, col, 'Cumsum', ax=axes[i]) plt.tight_layout(rect=[0, 0.02, 1, 0.98]) - if fname is not None: + if fname is not None: plt.savefig(fname) plt.show() @@ -133,7 +133,7 @@ def plot_distributions(self, nr_cols=3, fname=None): """ Plot the distribution plots for all columns in the real and fake dataset. Height of each row of plots scales with the length of the labels. Each plot contains the values of a real columns and the corresponding fake column. - :param fname: If not none, saves the plot with this file name. + :param fname: If not none, saves the plot with this file name. """ nr_charts = len(self.real.columns) nr_rows = max(1, nr_charts // nr_cols) @@ -175,7 +175,7 @@ def plot_distributions(self, nr_cols=3, fname=None): ax.set_xticklabels(axes[i].get_xticklabels(), rotation='vertical') plt.tight_layout(rect=[0, 0.02, 1, 0.98]) - if fname is not None: + if fname is not None: plt.savefig(fname) plt.show() @@ -213,8 +213,8 @@ def custom_cosine(a, b): else: raise ValueError(f'`how` parameter must be in [euclidean, mae, rmse]') - real_corr = compute_associations(self.real, nominal_columns=self.categorical_columns, theil_u=True) - fake_corr = compute_associations(self.fake, nominal_columns=self.categorical_columns, theil_u=True) + real_corr = associations(self.real, nominal_columns=self.categorical_columns, nom_nom_assoc='theil', compute_only=True) + fake_corr = associations(self.fake, nominal_columns=self.categorical_columns, nom_nom_assoc='theil', compute_only=True) return distance_func( real_corr.values, @@ -241,7 +241,7 @@ def plot_pca(self, fname=None): ax[0].set_title('Real data') ax[1].set_title('Fake data') - if fname is not None: + if fname is not None: plt.savefig(fname) plt.show() @@ -367,16 +367,16 @@ def score_estimators(self): def visual_evaluation(self, save_dir=None, **kwargs): """ Plot all visual 
evaluation metrics. Includes plotting the mean and standard deviation, cumulative sums, correlation differences and the PCA transform. - :save_dir: directory path to save images + :save_dir: directory path to save images :param kwargs: any kwargs for matplotlib. """ - if save_dir is None: + if save_dir is None: self.plot_mean_std() self.plot_cumsums() self.plot_distributions() self.plot_correlation_difference(**kwargs) - self.plot_pca() - else: + self.plot_pca() + else: save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) @@ -384,8 +384,8 @@ def visual_evaluation(self, save_dir=None, **kwargs): self.plot_cumsums(fname=save_dir/'cumsums.png') self.plot_distributions(fname=save_dir/'distributions.png') self.plot_correlation_difference(fname=save_dir/'correlation_difference.png', **kwargs) - self.plot_pca(fname=save_dir/'pca.png') - + self.plot_pca(fname=save_dir/'pca.png') + def basic_statistical_evaluation(self) -> float: """ @@ -428,7 +428,7 @@ def correlation_correlation(self) -> float: total_metrics = pd.DataFrame() for ds_name in ['real', 'fake']: ds = getattr(self, ds_name) - corr_df = compute_associations(ds, nominal_columns=self.categorical_columns, theil_u=True) + corr_df = associations(ds, nominal_columns=self.categorical_columns, nom_nom_assoc='theil', compute_only=True) values = corr_df.values values = values[~np.eye(values.shape[0], dtype=bool)].reshape(values.shape[0], -1) total_metrics[ds_name] = values.flatten() @@ -621,7 +621,7 @@ def evaluate(self, target_col: str, target_type: str = 'class', metric: str = No :param kfold: Use a 5-fold CV for the ML estimators if set to True. Train/Test on 80%/20% of the data if set to False. :param notebook: Better visualization of the results in a python notebook :param verbose: whether to print verbose logging. - :param return_outputs: Will omit printing and instead return a dictionairy with all results. 
+ :param return_outputs: Will omit printing and instead return a dictionary with all results. """ self.verbose = verbose if verbose is not None else self.verbose self.comparison_metric = metric if metric is not None else self.comparison_metric diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..3b92a1a --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +import sys +sys.path.append('..') \ No newline at end of file diff --git a/tests/create_test_data.ipynb b/tests/create_test_data.ipynb index a02a43c..f6f2b1b 100644 --- a/tests/create_test_data.ipynb +++ b/tests/create_test_data.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -11,16 +11,36 @@ "import numpy as np\n", "from pathlib import Path\n", "import dython\n", - "from dython.nominal import associations, numerical_encoding, compute_associations" + "from dython.nominal import associations, numerical_encoding, associations" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "assert dython.__version__ == '0.5.1', 'Dython version should be version 0.5.1'" + "assert dython.__version__ == '0.7.3', 'Dython version should be version 0.7.3'" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.5.2'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.__version__" ] }, { @@ -32,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -60,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ 
-70,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -87,20 +107,20 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "real_assoc = compute_associations(real, nominal_columns=nominal_cols)\n", - "real_assoc_theil = compute_associations(real, nominal_columns=nominal_cols, theil_u=True)\n", + "real_assoc = associations(real, nominal_columns=nominal_cols, compute_only=True)['corr']\n", + "real_assoc_theil = associations(real, nominal_columns=nominal_cols, nom_nom_assoc='theil', compute_only=True)['corr']\n", "\n", - "fake_assoc = compute_associations(fake, nominal_columns=nominal_cols)\n", - "fake_assoc_theil = compute_associations(fake, nominal_columns=nominal_cols, theil_u=True)" + "fake_assoc = associations(fake, nominal_columns=nominal_cols, compute_only=True)['corr']\n", + "fake_assoc_theil = associations(fake, nominal_columns=nominal_cols, nom_nom_assoc='theil', compute_only=True)['corr']" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -117,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -203,8 +223,8 @@ " 0.193900\n", " 0.124192\n", " 1.000000\n", - " 0.459150\n", - " 0.250205\n", + " 0.904474\n", + " 0.482048\n", " 0.031496\n", " \n", " \n", @@ -213,9 +233,9 @@ " 0.088169\n", " 0.531158\n", " 0.230062\n", - " 0.904474\n", + " 0.459150\n", " 1.000000\n", - " 0.617749\n", + " 0.604179\n", " 0.093090\n", " \n", " \n", @@ -224,8 +244,8 @@ " 0.164624\n", " 0.494940\n", " 0.178185\n", - " 0.482048\n", - " 0.604179\n", + " 0.250205\n", + " 0.617749\n", " 1.000000\n", " 0.107945\n", " \n", @@ -260,13 +280,13 @@ "account_id 0.048279 0.088169 0.164624 0.052132 \n", "trans_amount 0.193900 0.531158 0.494940 0.016639 \n", "balance_after_trans 0.124192 0.230062 0.178185 
0.123935 \n", - "trans_type 1.000000 0.459150 0.250205 0.031496 \n", - "trans_operation 0.904474 1.000000 0.617749 0.093090 \n", - "trans_k_symbol 0.482048 0.604179 1.000000 0.107945 \n", + "trans_type 1.000000 0.904474 0.482048 0.031496 \n", + "trans_operation 0.459150 1.000000 0.604179 0.093090 \n", + "trans_k_symbol 0.250205 0.617749 1.000000 0.107945 \n", "trans_date 0.031496 0.093090 0.107945 1.000000 " ] }, - "execution_count": 55, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -285,7 +305,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "base", "language": "python", "name": "python3" }, @@ -299,7 +319,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "7f01e88ea25a1ac3bb5bad9a6866e291f76248279393d8536ad2ed16e1233702" + } } }, "nbformat": 4, diff --git a/tests/metrics_test.py b/tests/metrics_test.py index 4ca7ee0..3dc3e22 100644 --- a/tests/metrics_test.py +++ b/tests/metrics_test.py @@ -1,12 +1,11 @@ import pytest import pandas as pd import numpy as np -from numpy.testing import assert_almost_equal, assert_array_almost_equal import sys sys.path.append('..') from table_evaluator.metrics import * from table_evaluator.utils import load_data -from dython.nominal import compute_associations, numerical_encoding +from dython.nominal import associations, numerical_encoding from pathlib import Path data_folder = Path('data') @@ -46,10 +45,10 @@ def test_associations(): fake_assoc_theil = pd.read_csv(test_data_folder/'fake_associations_theil.csv', index_col='Unnamed: 0') # Assert equality with saved data - pd.testing.assert_frame_equal(real_assoc, compute_associations(real, nominal_columns=cat_cols)) - pd.testing.assert_frame_equal(real_assoc_theil, compute_associations(real, nominal_columns=cat_cols, theil_u=True)) - pd.testing.assert_frame_equal(fake_assoc, 
compute_associations(fake, nominal_columns=cat_cols)) - pd.testing.assert_frame_equal(fake_assoc_theil, compute_associations(fake, nominal_columns=cat_cols, theil_u=True)) + pd.testing.assert_frame_equal(real_assoc, associations(real, nominal_columns=cat_cols, compute_only=True)['corr']) + pd.testing.assert_frame_equal(real_assoc_theil, associations(real, nominal_columns=cat_cols, nom_nom_assoc='theil', compute_only=True)['corr']) + pd.testing.assert_frame_equal(fake_assoc, associations(fake, nominal_columns=cat_cols, compute_only=True)['corr']) + pd.testing.assert_frame_equal(fake_assoc_theil, associations(fake, nominal_columns=cat_cols, nom_nom_assoc='theil', compute_only=True)['corr']) def test_numerical_encoding():