Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small improvements #16

Merged
merged 8 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Debugging
run: |
ls -la
cat Makefile
make virtualenv
- name: Install project
run: |
make virtualenv
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ jobs:
pip install -r requirements.txt
- name: Analysing the code with pylint
run: |
pylint --fail-under=4 $(git ls-files '*.py')
pylint --fail-under=6 $(git ls-files '*.py')
# pylint $(git ls-files '*.py')
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ These could include visualizing the results for a binary classifier, for which p
|:--------------------------------------------------:|:----------------------------------------------------------:|:-------------------------------------------------:|
| Calibration Curve | Classification Report | Confusion Matrix |

| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve_bootstrap.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/y_prob_histogram.png?raw=true" width="300" alt="Your Image"> |
| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/roc_curve_bootstrap.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/pr_curve.png?raw=true" width="300" alt="Your Image"> | <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/y_prob_histogram.png?raw=true" width="300" alt="Your Image"> |
|:--------------------------------------------------:|:----------------------------------------------------------:|:-------------------------------------------------:|
| ROC Curve (AUROC) | ROC Curve (AUROC) with bootstrapping | y_prob histogram |
| ROC Curve (AUROC) with bootstrapping | Precision-Recall Curve | y_prob histogram |


| <img src="https://github.com/joshuawe/plots_and_graphs/blob/main/images/raincloud.png?raw=true" width="300" alt="Your Image"> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> | <img src="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7" width="300" height="300" alt=""> |
Expand All @@ -82,7 +82,7 @@ Install the package via pip.
pip install plotsandgraphs
```

Alternativelynstall the package from git.
Alternatively install the package from git.
```bash
git clone https://github.com/joshuawe/plots_and_graphs
cd plots_and_graphs
Expand Down
Binary file added images/pr_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/y_prob_histogram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
78 changes: 39 additions & 39 deletions plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
import seaborn as sns
import numpy as np
import pandas as pd
from sklearn.metrics import (
Expand All @@ -15,9 +16,7 @@
)
from sklearn.calibration import calibration_curve
from sklearn.utils import resample
from pathlib import Path
from tqdm import tqdm
from typing import Optional


def plot_accuracy(y_true, y_pred, name="", save_fig_path=None) -> Figure:
Expand All @@ -28,27 +27,25 @@
# for t in range(max_seq_len):
# accuracy += accuracy_score( y[:,t,0].round() , y_pred[:,t] )
# accuracy = accuracy / max_seq_len
fig = plt.figure(figsize=(4, 5))
plt.bar(np.array([0]), np.array([accuracy]))

Check warning on line 31 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L30-L31

Added lines #L30 - L31 were not covered by tests
# axs[0].set_xticks(ticks=range(2))
# axs[0].set_xticklabels(["train", "test"])
plt.ylabel("Accuracy")
plt.ylim([0, 1])

Check warning on line 35 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L34-L35

Added lines #L34 - L35 were not covered by tests
# axs[0].set_xlabel('Features')
title = "Predictor model: {}".format(name)

Check warning on line 37 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L37

Added line #L37 was not covered by tests
plt.title(title)
plt.tight_layout()

if save_fig_path != None:
if save_fig_path is not None:

Check warning on line 41 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L41

Added line #L41 was not covered by tests
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")

Check warning on line 44 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L44

Added line #L44 was not covered by tests
return fig, accuracy
return fig


def plot_confusion_matrix(
y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None
) -> Figure:
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None) -> Figure:
import matplotlib.colors as colors

# Compute the confusion matrix
Expand All @@ -57,16 +54,14 @@
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

# Create the ConfusionMatrixDisplay instance and plot it
cmd = ConfusionMatrixDisplay(
cm, display_labels=["class 0\nnegative", "class 1\npositive"]
)
cmd = ConfusionMatrixDisplay(cm, display_labels=["class 0\nnegative", "class 1\npositive"])
fig, ax = plt.subplots(figsize=(4, 4))
cmd.plot(
cmap="YlOrRd",
values_format="",
colorbar=False,
ax=ax,
text_kw={"visible": False},
# text_kw={"visible": False},
)
cmd.texts_ = []
cmd.text_ = []
Expand Down Expand Up @@ -106,7 +101,7 @@
cbar.outline.set_visible(False)
plt.tight_layout()

if save_fig_path != None:
if save_fig_path is not None:
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")
Expand All @@ -115,7 +110,7 @@


def plot_classification_report(
y_test: np.ndarray,
y_true: np.ndarray,
y_pred: np.ndarray,
title="Classification Report",
figsize=(8, 4),
Expand Down Expand Up @@ -152,32 +147,27 @@
import matplotlib as mpl
import matplotlib.colors as colors
import seaborn as sns
import pathlib

fig, ax = plt.subplots(figsize=figsize)

cmap = "YlOrRd"

clf_report = classification_report(y_test, y_pred, output_dict=True, **kwargs)
keys_to_plot = [
key
for key in clf_report.keys()
if key not in ("accuracy", "macro avg", "weighted avg")
]
clf_report = classification_report(y_true, y_pred, output_dict=True, **kwargs)
keys_to_plot = [key for key in clf_report.keys() if key not in ("accuracy", "macro avg", "weighted avg")]

Check warning on line 156 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L156

Added line #L156 was not covered by tests
df = pd.DataFrame(clf_report, columns=keys_to_plot).T
# the following line ensures that dataframe are sorted from the majority classes to the minority classes
df.sort_values(by=["support"], inplace=True)

Check warning on line 159 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L159

Added line #L159 was not covered by tests

# first, let's plot the heatmap by masking the 'support' column
rows, cols = df.shape
mask = np.zeros(df.shape)
mask[:, cols - 1] = True

Check warning on line 164 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L164

Added line #L164 was not covered by tests

bounds = np.linspace(0, 1, 11)
cmap = plt.cm.get_cmap("YlOrRd", len(bounds) + 1)
norm = colors.BoundaryNorm(bounds, cmap.N) # type: ignore[attr-defined]
cmap = plt.cm.get_cmap("YlOrRd", len(bounds) + 1) # type: ignore[assignment]
norm = colors.BoundaryNorm(bounds, cmap.N) # type: ignore[attr-defined]

Check warning on line 168 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L167-L168

Added lines #L167 - L168 were not covered by tests

ax = sns.heatmap(

Check warning on line 170 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L170

Added line #L170 was not covered by tests
df,
mask=mask,
annot=False,
Expand All @@ -190,16 +180,16 @@
linecolor="white",
)
cbar = ax.collections[0].colorbar
cbar.ax.yaxis.set_ticks_position("both")

Check warning on line 183 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L183

Added line #L183 was not covered by tests

cmap_min, cmap_max = cbar.cmap(0), cbar.cmap(1.0)

# add text annotation to heatmap
dx, dy = 0.5, 0.5
for i in range(rows):
for j in range(cols - 1):
text = f"{df.iloc[i, j]:.2%}" # if (j<cols) else f"{df.iloc[i, j]:.0f}"
ax.text(

Check warning on line 192 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L190-L192

Added lines #L190 - L192 were not covered by tests
j + dx,
i + dy,
text,
Expand All @@ -211,9 +201,9 @@

# then, let's add the support column by normalizing the colors in this column
mask = np.zeros(df.shape)
mask[:, : cols - 1] = True

Check warning on line 204 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L204

Added line #L204 was not covered by tests

ax = sns.heatmap(

Check warning on line 206 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L206

Added line #L206 was not covered by tests
df,
mask=mask,
annot=False,
Expand All @@ -229,10 +219,10 @@

cmap_min, cmap_max = cbar.cmap(0), cbar.cmap(1.0)
for i in range(rows):
j = cols - 1
text = f"{df.iloc[i, j]:.0f}" # if (j<cols) else f"{df.iloc[i, j]:.0f}"
color = (df.iloc[i, j]) / (df["support"].sum())
ax.text(

Check warning on line 225 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L222-L225

Added lines #L222 - L225 were not covered by tests
j + dx,
i + dy,
text,
Expand All @@ -243,16 +233,16 @@
)

plt.title(title)
plt.xticks(rotation=45)
plt.yticks(rotation=360)

Check warning on line 237 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L236-L237

Added lines #L236 - L237 were not covered by tests
plt.tight_layout()

if save_fig_path != None:
if save_fig_path is not None:

Check warning on line 240 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L240

Added line #L240 was not covered by tests
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")

Check warning on line 243 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L243

Added line #L243 was not covered by tests

return fig, ax

Check warning on line 245 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L245

Added line #L245 was not covered by tests


def plot_roc_curve(
Expand Down Expand Up @@ -332,9 +322,7 @@
auc_upper = np.quantile(bootstrap_aucs, CI_upper)
auc_lower = np.quantile(bootstrap_aucs, CI_lower)
label = f"{confidence_interval:.0%} CI: [{auc_lower:.2f}, {auc_upper:.2f}]"
plt.fill_between(
base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2
)
plt.fill_between(base_fpr, tprs_lower, tprs_upper, alpha=0.3, label=label, zorder=2)

if highlight_roc_area is True:
print(
Expand Down Expand Up @@ -366,9 +354,7 @@
return fig


def plot_calibration_curve(
y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, save_fig_path=None
) -> Figure:
def plot_calibration_curve(y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, save_fig_path=None) -> Figure:
"""
Creates calibration plot for a binary classifier and calculates the ECE.

Expand All @@ -390,26 +376,24 @@
ece : float
The expected calibration error.
"""
prob_true, prob_pred = calibration_curve(
y_true, y_prob, n_bins=n_bins, strategy="uniform"
)
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy="uniform")

Check warning on line 379 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L379

Added line #L379 was not covered by tests

# Find the number of samples in each bin
bin_counts = np.histogram(y_prob, bins=n_bins, range=(0, 1))[0]

Check warning on line 382 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L382

Added line #L382 was not covered by tests

# Calculate the weighted absolute difference (ECE)
ece = np.abs(prob_pred - prob_true) * (bin_counts / len(y_prob))
ece = ece.sum().round(2)

Check warning on line 386 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L385-L386

Added lines #L385 - L386 were not covered by tests

fig = plt.figure(figsize=(5, 5))

Check warning on line 388 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L388

Added line #L388 was not covered by tests
ax = fig.add_subplot(111)

# Evenly spaced bar locations on the x-axis and reduced bar width for spacing
bar_centers = np.linspace(0, 1, n_bins, endpoint=False) + 0.5 / n_bins
bar_width = 1.0 / n_bins # * 0.9 # 90% of the bin width to create gaps

Check warning on line 393 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L392-L393

Added lines #L392 - L393 were not covered by tests

# Plotting
ax.bar(

Check warning on line 396 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L396

Added line #L396 was not covered by tests
bar_centers,
prob_true,
width=bar_width,
Expand All @@ -420,7 +404,7 @@
linewidth=2,
label=f"True Calibration",
)
ax.bar(

Check warning on line 407 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L407

Added line #L407 was not covered by tests
bar_centers,
prob_pred - prob_true,
bottom=prob_true,
Expand All @@ -434,7 +418,7 @@
label=f"Mean ECE = {ece}",
hatch="//",
)
ax.plot(

Check warning on line 421 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L421

Added line #L421 was not covered by tests
[0, 1],
[0, 1],
linestyle="--",
Expand All @@ -444,28 +428,28 @@
)

# Labels and titles
ax.set(xlabel="Predicted probability", ylabel="True probability")

Check warning on line 431 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L431

Added line #L431 was not covered by tests
plt.xlim([0.0, 1.005])
plt.ylim([-0.01, 1.0])
ax.legend(loc="upper left", frameon=False)

Check warning on line 434 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L434

Added line #L434 was not covered by tests

# show y-grid
ax.spines[:].set_visible(False)
ax.grid(True, linestyle="-", linewidth=0.5, color="grey", alpha=0.5)

Check warning on line 438 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L438

Added line #L438 was not covered by tests
ax.set_yticks(np.arange(0, 1.1, 0.2))
ax.set_xticks(np.arange(0, 1.1, 0.2))
plt.tight_layout()

# save plot
if save_fig_path is not None:

Check warning on line 444 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L444

Added line #L444 was not covered by tests
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")

Check warning on line 447 in plotsandgraphs/binary_classifier.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/binary_classifier.py#L447

Added line #L447 was not covered by tests

return fig


def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray]=None, save_fig_path=None) -> Figure:
def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray] = None, save_fig_path=None) -> Figure:
"""
Provides a histogram for the predicted probabilities of a binary classifier. If ```y_true``` is provided, it divides the ```y_prob``` values into the two classes and plots them jointly into the same plot with different colors.

Expand All @@ -485,16 +469,32 @@
"""
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111)

if y_true is None:
ax.hist(y_prob, bins=10, alpha=0.9, edgecolor="midnightblue", linewidth=2, rwidth=1)
# same histogram as above, but with border lines
# ax.hist(y_prob, bins=10, alpha=0.5, edgecolor='black', linewidth=1.2)
else:
alpha = 0.6
ax.hist(y_prob[y_true==0], bins=10, alpha=alpha, edgecolor="midnightblue", linewidth=2, rwidth=1, label="$\\hat{y} = 0$")
ax.hist(y_prob[y_true==1], bins=10, alpha=alpha, edgecolor="darkred", linewidth=2, rwidth=1, label="$\\hat{y} = 1$")

ax.hist(
y_prob[y_true == 0],
bins=10,
alpha=alpha,
edgecolor="midnightblue",
linewidth=2,
rwidth=1,
label="$\\hat{y} = 0$",
)
ax.hist(
y_prob[y_true == 1],
bins=10,
alpha=alpha,
edgecolor="darkred",
linewidth=2,
rwidth=1,
label="$\\hat{y} = 1$",
)

plt.legend()
ax.set(xlabel="Predicted probability [-]", ylabel="Count [-]", xlim=(-0.01, 1.0))
ax.set_title("Histogram of predicted probabilities")
Expand All @@ -505,7 +505,7 @@
plt.tight_layout()

# save plot
if save_fig_path != None:
if save_fig_path is not None:
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches="tight")
Expand Down
2 changes: 1 addition & 1 deletion plotsandgraphs/compare_distributions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from typing import List, Tuple, Optional


def plot_raincloud(
Expand Down Expand Up @@ -46,15 +46,15 @@

# if colors are none, use distinct colors for each group
if colors is None:
cmap = plt.get_cmap("tab10")

Check warning on line 49 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L49

Added line #L49 was not covered by tests
colors = [mpl.colors.to_hex(cmap(i)) for i in np.linspace(0, 1, len(order))]
else:
assert len(colors) == len(order), "colors and order must be the same length"

Check warning on line 52 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L52

Added line #L52 was not covered by tests
colors = colors

# Boxplot
if show_boxplot:
bp = ax.boxplot(

Check warning on line 57 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L57

Added line #L57 was not covered by tests
[df[df[y_col] == grp][x_col].values for grp in order],
patch_artist=True,
vert=False,
Expand All @@ -63,17 +63,17 @@
)

# Customize boxplot colors
for patch, color in zip(bp["boxes"], colors):

Check warning on line 66 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L66

Added line #L66 was not covered by tests
patch.set_facecolor(color)
patch.set_alpha(0.8)

# Set median line color to black
for median in bp["medians"]:
median.set_color("black")

Check warning on line 72 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L71-L72

Added lines #L71 - L72 were not covered by tests

# Violinplot
if show_violin:
vp = ax.violinplot(

Check warning on line 76 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L76

Added line #L76 was not covered by tests
[df[df[y_col] == grp][x_col].values for grp in order],
positions=np.arange(1 + offset, len(order) + 1 + offset),
showmeans=False,
Expand All @@ -83,8 +83,8 @@
)

# Customize violinplot colors
for idx, b in enumerate(vp["bodies"]):
b.get_paths()[0].vertices[:, 1] = np.clip(

Check warning on line 87 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L86-L87

Added lines #L86 - L87 were not covered by tests
b.get_paths()[0].vertices[:, 1], idx + 1 + offset, idx + 2 + offset
)
b.set_color(colors[idx])
Expand All @@ -96,7 +96,7 @@
y = np.full(len(features), idx + 1 - offset)
jitter_amount = 0.12
y += np.random.uniform(low=-jitter_amount, high=jitter_amount, size=len(y))
plt.scatter(features, y, s=10, c=colors[idx], alpha=0.3, facecolors="none")

Check warning on line 99 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L99

Added line #L99 was not covered by tests

# Labels
plt.yticks(np.arange(1, len(order) + 1), order)
Expand All @@ -105,10 +105,10 @@
x_label = x_col
plt.xlabel(x_label)
if title:
plt.title(title + "\n")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)

Check warning on line 111 in plotsandgraphs/compare_distributions.py

View check run for this annotation

Codecov / codecov/patch

plotsandgraphs/compare_distributions.py#L108-L111

Added lines #L108 - L111 were not covered by tests
ax.xaxis.grid(True)

if x_range:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@ max-line-length = 120
[tool.pylint."BASIC"]
variable-rgx = "[a-z_][a-z0-9_]{0,30}$|[a-z0-9_]+([A-Z][a-z0-9_]+)*$" # Allow snake case and camel case for variable names

[tool.pylint."MESSAGES CONTROL"]
disable = "W0621" # Allow redefining names in outer scope

[flake8]
max-line-length = 120
7 changes: 3 additions & 4 deletions src/binary_classifier.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Optional
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
import seaborn as sns
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay, roc_curve, auc, accuracy_score, precision_recall_curve
from sklearn.calibration import calibration_curve
from sklearn.utils import resample
from pathlib import Path
from tqdm import tqdm
from typing import Optional


def plot_accuracy(y_true, y_pred, name='', save_fig_path=None) -> Figure:
Expand Down Expand Up @@ -381,7 +380,7 @@ def plot_y_prob_histogram(y_prob: np.ndarray, save_fig_path=None) -> Figure:
plt.tight_layout()

# save plot
if (save_fig_path != None):
if (save_fig_path is not None):
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches='tight')
Expand Down
17 changes: 8 additions & 9 deletions src/compare_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

def plot_raincloud(df: pd.DataFrame,
x_col: str,
y_col: str,
colors: List[str] = None,
order: List[str] = None,
title: str = None,
x_label: str = None,
x_range: Tuple[float, float] = None,
show_violin = True,
show_scatter = True,
y_col: str,
colors: List[str] = None,
order: List[str] = None,
title: str = None,
x_label: str = None,
x_range: Tuple[float, float] = None,
show_violin = True,
show_scatter = True,
show_boxplot = True):

"""
Expand Down Expand Up @@ -49,7 +49,6 @@ def plot_raincloud(df: pd.DataFrame,
colors = [mpl.colors.to_hex(cmap(i)) for i in np.linspace(0, 1, len(order))]
else:
assert len(colors) == len(order), 'colors and order must be the same length'
colors = colors

# Boxplot
if show_boxplot:
Expand Down
7 changes: 7 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

TEST_RESULTS_PATH = os.path.join(os.path.dirname(__file__), "test_results")

# print cwd in console

# print os.path.dirname(__file__)
Loading