Skip to content

Commit

Permalink
Change matrix to boxplot, jitter, and annotate
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanjameskennedy committed Jul 18, 2024
1 parent 098bfd6 commit ea8319c
Showing 1 changed file with 41 additions and 20 deletions.
61 changes: 41 additions & 20 deletions jasentool/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def generate_matrix(self, sample_ids, get_cgmlst_data):

def plot_heatmap(self, distance_df, output_plot_fpath):
plt.figure(figsize=(10, 8))
sns.heatmap(distance_df, annot=True, cmap='coolwarm', center=0)
sns.heatmap(distance_df, annot=True, cmap="coolwarm", center=0)
plt.title("Differential Matrix Heatmap of cgmlst")
plt.xlabel("Jasen")
plt.ylabel("Cgviz")
Expand All @@ -103,31 +103,50 @@ def plot_barplot(self, count_dict, output_plot_fpath):
print(f"The number of alleles that aren't null for more than 1000 samples is {len(categories)}")

plt.figure(figsize=(10, 8))
bars = plt.bar(categories, counts, color='skyblue')
bars = plt.bar(categories, counts, color="skyblue")

# Add titles and labels
plt.xlabel('Alleles')
plt.ylabel('Count')
plt.title('Null Allele Count Bar Plot')
plt.xlabel("Alleles")
plt.ylabel("Count")
plt.title("Null Allele Count Bar Plot")

# Rotate the x-axis labels by 90 degrees
plt.xticks(rotation=90)

# Add value labels on top of the bars
for bar in bars:
yval = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2, yval + 1, yval, ha='center', va='bottom')
plt.text(bar.get_x() + bar.get_width()/2, yval + 1, yval, ha="center", va="bottom")

plt.tight_layout()
plt.savefig(output_plot_fpath, dpi=600)

def plot_matrix_barplot(self, df, output_plot_fpath):
def plot_matrix_boxplot(self, df, output_plot_fpath):
plt.figure(figsize=(10, 8))
plt.bar(df.index, df['sum'], color='skyblue')
plt.xlabel('Sample')
plt.ylabel('Sum of sample allele differences')
counts = list(df["sum"])
sample_ids = list(df["SampleID"])
plt.boxplot(counts)

# Add jittered data points
jitter = 0.04 # Adjust the jitter as needed
x_jitter = np.random.normal(1, jitter, size=len(counts))
plt.scatter(x_jitter, counts, alpha=0.5, color="blue")

# Set labels and title
plt.xlabel("Samples")
plt.ylabel("Sum of sample allele differences")
plt.title("Summed differential matrix of distances between pipelines' cgMLST results")
plt.xticks(rotation=90)

# Annotate outliers
for i, count in enumerate(counts):
if count > 250000 or count < -750000:
if float(x_jitter[i]) < 1:
plt.annotate(f"{sample_ids[i]}", xy=(x_jitter[i] - 0.01, count), xytext=(x_jitter[i] - 0.01, count),
horizontalalignment="right", fontsize=8)
else:
plt.annotate(f"{sample_ids[i]}", xy=(x_jitter[i] - 0.01, count), xytext=(x_jitter[i] + 0.01, count),
horizontalalignment="left", fontsize=8)

plt.tight_layout()
plt.savefig(output_plot_fpath, dpi=600)

Expand All @@ -137,22 +156,22 @@ def plot_boxplot(self, count_dict, output_plot_fpath):
plt.boxplot(counts, vert=True, patch_artist=True) # `vert=True` for vertical boxplot, `patch_artist=True` for filled boxes

# Add title and labels
plt.xlabel('Null allele count')
plt.title('Number of null alleles per sample')
plt.xlabel("Null allele count")
plt.title("Number of null alleles per sample")

min_value = np.min(counts)

# Label the minimum value on the plot
plt.annotate(f'Min: {min_value}', xy=(1, min_value), xytext=(1.05, min_value),
arrowprops=dict(facecolor='black', shrink=0.05),
horizontalalignment='left')
plt.annotate(f"Min: {min_value}", xy=(1, min_value), xytext=(1.05, min_value),
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="left")

plt.savefig(output_plot_fpath, dpi=600)

def run(self, input_files, output_fpaths, generate_matrix):
# heatmap_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "cgviz_vs_jasen_heatmap.png")
output_csv_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "cgviz_vs_jasen.csv")
barplot_matrix_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "summed_differential_matrix_barplot.png")
boxplot_matrix_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "summed_differential_matrix_boxplot.png")
barplot_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "null_alleles_barplot.png")
boxplot_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "sample_null_boxplot.png")
null_alleles_count, sample_null_count = self.get_null_allele_counts(input_files)
Expand All @@ -168,6 +187,8 @@ def run(self, input_files, output_fpaths, generate_matrix):
# self.plot_heatmap(distance_df, output_plot_fpath)
if os.path.exists(output_csv_fpath):
distance_df = pd.read_csv(output_csv_fpath, index_col=0)
distance_df['sum'] = distance_df.sum(axis=1)
filtered_df = distance_df[distance_df['sum'] >= 100]
self.plot_matrix_barplot(filtered_df, barplot_matrix_fpath)
distance_df["sum"] = distance_df.sum(axis=1)
distance_df = distance_df.reset_index()
distance_df.rename(columns={'index': 'SampleID'}, inplace=True)
filtered_df = distance_df[["SampleID", "sum"]]
self.plot_matrix_boxplot(filtered_df, boxplot_matrix_fpath)

0 comments on commit ea8319c

Please sign in to comment.