Skip to content

Commit

Permalink
Merge pull request #15 from HBPMedical/feat/add-matching-widget-and-r…
Browse files Browse the repository at this point in the history
…efine-embedding-widget

refactor+feat: refine widget for embedding visualization and add widget to visualize matching distance results
  • Loading branch information
sebastientourbier authored May 11, 2023
2 parents a484373 + 9a6a961 commit 9024bf3
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 165 deletions.
7 changes: 7 additions & 0 deletions mip_dmp/plot/embedding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Module to plot the embeddings of the column names and CDE codes."""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -136,9 +138,14 @@ def pick_event_method(event):
def handle_annotations(artist, indices):
"""Add text annotations to closest point of the cursor when the mouse button was pressed."""
# Get the type of the artist that can be "cde" or "column"
artist_type = None
for k in artists.keys():
if artist == artists[k]:
artist_type = k
# If the artist type is not defined yet, return
if artist_type is None:
return
# Get the dataframe of the artist type ("cde" or "column")
artist_df = df[df["type"] == artist_type]
# For each index of the artist
for ind in indices:
Expand Down
83 changes: 83 additions & 0 deletions mip_dmp/plot/matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Module to plot the initial matching results between the input dataset columns and the target CDE codes."""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Define colors used to plot the column and CDE code embeddings
# '#4fa08b' green / '#009E73' green / '#0072B2' blue / '#FFA500' orange
# NOTE(review): COLORS is not referenced by the heatmap code visible in this
# module -- confirm whether it is still needed here.
COLORS = ["#0072B2", "#FFA500"]
# Set seaborn style
# NOTE(review): these calls run at import time; presumably they change the
# global seaborn/matplotlib theme for every figure drawn afterwards -- confirm.
sns.set_style("darkgrid")
# Override rc parameters: "#081512" figure/axes/3D-pane backgrounds, white
# text, edges, ticks and labels, and a "#4fa08b" grid.
sns.set(
rc={
"axes.facecolor": "#081512",
"figure.facecolor": "#081512",
"text.color": "white",
"axes.edgecolor": "white",
"patch.edgecolor": "#081512",
"xtick.color": "white",
"ytick.color": "white",
"axes.labelcolor": "white",
"grid.color": "#4fa08b",
"axes3d.xaxis.panecolor": "#081512",
"axes3d.yaxis.panecolor": "#081512",
"axes3d.zaxis.panecolor": "#081512",
"ytick.major.pad": 8,
}
)


def heatmap_matching(
    figure, matrix, inputDatasetColumns, targetCDECodes, matchingMethod
):
    """Render a heatmap of the initial matching results between the input dataset columns and the target CDE codes.

    Parameters
    ----------
    figure: matplotlib.figure.Figure
        Figure to render the heatmap of the matching results.
        A new axes is added to it on every call.
    matrix: numpy.ndarray
        Similarity / distance matrix of the matching results.
    inputDatasetColumns: list
        List of the input dataset columns. Used as ytick labels.
    targetCDECodes: list
        List of the target CDE codes. Used as xtick labels.
    matchingMethod: str
        Matching method used to generate the similarity / distance matrix.
        Used to generate the title of the figure. ``"fuzzy"`` selects the
        Levenshtein-ratio based label; any other value is labelled as
        cosine distance.

    Returns
    -------
    figure: matplotlib.figure.Figure
        The figure passed in, now holding the heatmap axes.
    """
    # Generate the figure
    left, bottom, width, height = 0.2, 0.1, 0.8, 0.2
    ax = figure.add_axes([left, bottom, width, height])
    xtickLabels = targetCDECodes
    ytickLabels = inputDatasetColumns
    sns.heatmap(
        matrix,
        ax=ax,
        xticklabels=xtickLabels,
        yticklabels=ytickLabels,
        annot=True,
        fmt=".2f",
        cmap="viridis",
    )
    # Move the x ticks and their labels (the CDE codes) to the top edge.
    ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)

    distance_type = (
        r"$(1 - 0.01 * \mathrm{LevenshteinRatio})$"
        if matchingMethod == "fuzzy"
        else "Cosine distance"
    )
    title = (
        f"Distances for the most 10 similar CDE codes\n"
        f"(method: {matchingMethod}, distance: {distance_type})"
    )
    ax.set_title(title)
    # Rotate the tick labels on the heatmap axes itself. The pyplot helpers
    # (plt.xticks / plt.yticks) act on pyplot's *current* axes, which is not
    # guaranteed to be `ax` when `figure` was created outside of pyplot
    # (e.g. a Figure embedded in a Qt canvas).
    ax.tick_params(axis="x", labelrotation=75)
    ax.tick_params(axis="y", labelrotation=90)
    return figure
50 changes: 46 additions & 4 deletions mip_dmp/process/matching.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module that provides functions to support the matching of dataset columns to CDEs."""

# External imports
import ast
import numpy as np
import pandas as pd
from fuzzywuzzy import fuzz

Expand Down Expand Up @@ -95,13 +97,13 @@ def match_columns_to_cdes(
] # Select the nb_kept_matches first matched CDE codes.
)
)
# Store the first nb_fuzy_matches matched CDE codes in the dictionary.
# Store the first nb_fuzzy_matches matched CDE codes in the dictionary.
for i, dataset_column in enumerate(dataset.columns):
words = ast.literal_eval(matches.to_list()[i])
matched_cde_codes[dataset_column] = {
"words": matches[i][:nb_kept_matches],
"words": words,
"distances": [
fuzz.ratio(dataset_column, match)
for match in matches[i][:nb_kept_matches]
(1 - 0.01 * fuzz.ratio(dataset_column, match)) for match in words
],
"embeddings": [None] * nb_kept_matches,
}
Expand Down Expand Up @@ -296,3 +298,43 @@ def generate_initial_transform(dataset_column_values, cde_code_values, dataset_c
for dataset_column_value in dataset_column_values
}
)


def make_distance_vector(matchedCdeCodes, inputDatasetColumn):
    """Build the row vector of match distances for one input dataset column.

    Parameters
    ----------
    matchedCdeCodes : dict
        Dictionary of the matching results in the form::

            {
                "inputDatasetColumn1": {
                    "words": ["word1", "word2", ...],
                    "distances": [distance1, distance2, ...],
                    "embeddings": [embedding1, embedding2, ...]
                },
                "inputDatasetColumn2": {
                    "words": ["word1", "word2", ...],
                    "distances": [distance1, distance2, ...],
                    "embeddings": [embedding1, embedding2, ...]
                },
                ...
            }

    inputDatasetColumn : str
        Input dataset column name. Must be a key of ``matchedCdeCodes``.

    Returns
    -------
    distanceVector : numpy.ndarray
        Array of shape ``(1, n)`` holding the stored distances, where ``n``
        is the number of matched words of the column.
    """
    # Look up the matching results of the requested input dataset column
    columnMatches = matchedCdeCodes[inputDatasetColumn]
    # One row, one entry per matched word
    nbMatches = len(columnMatches["words"])
    distanceVector = np.zeros((1, nbMatches))
    # Fill the single row with the stored distances
    distanceVector[0] = columnMatches["distances"]
    return distanceVector
Binary file added mip_dmp/qt5/assets/heatmap_matching.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
43 changes: 39 additions & 4 deletions mip_dmp/qt5/components/dataset_mapper_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@
# NoEditorDelegate,
PandasTableModel,
)
from mip_dmp.qt5.components.embedding_visualization_widget import (
WordEmbeddingVisualizationWidget,
)
from mip_dmp.qt5.components.matching_visualization_widget import (
MappingMatchVisualizationWidget,
MatchingVisualizationWidget,
)

# Constants
WINDOW_NAME = "MIPDatasetMapperUI"
WINDOW_NAME = "MIP Dataset Mapper"


class MIPDatasetMapperWindow(object):
Expand Down Expand Up @@ -135,6 +138,8 @@ class MIPDatasetMapperWindow(object):
"embeddingCanvas",
"inputDatasetColumnEmbeddings",
"targetCDEsEmbeddings",
"matchingVizButton",
"matchingWidget",
]

def __init__(self, mainWindow):
Expand Down Expand Up @@ -272,13 +277,23 @@ def createToolBar(self, mainWindow):
mainWindow,
)
self.toolBar.addAction(self.mappingInitButton)
self.matchingVizButton = QAction(
QIcon(
pkg_resources.resource_filename(
"mip_dmp", os.path.join("qt5", "assets", "heatmap_matching.png")
)
),
"Visualize Column /CDE Match Distances",
mainWindow,
)
self.toolBar.addAction(self.matchingVizButton)
self.embeddingVizButton = QAction(
QIcon(
pkg_resources.resource_filename(
"mip_dmp", os.path.join("qt5", "assets", "plot_embedding.png")
)
),
"Plot Word Embedding in 3D with t-SNE",
"Visualize Word Embedding Matches in 3D (Enabled only for GloVe and Chars2Vec methods)",
mainWindow,
)
self.toolBar.addAction(self.embeddingVizButton)
Expand Down Expand Up @@ -596,10 +611,11 @@ def connectButtons(self):
self.mappingTableViewAddRowButton.clicked.connect(self.addMappingTableRow)
self.mappingTableViewDeleteRowButton.clicked.connect(self.deleteMappingTableRow)
self.embeddingVizButton.triggered.connect(self.embeddingViz)
self.matchingVizButton.triggered.connect(self.matchingViz)

def embeddingViz(self):
"""Open the embedding visualization window."""
self.embeddingWidget = MappingMatchVisualizationWidget()
self.embeddingWidget = WordEmbeddingVisualizationWidget()
print(
"Launch visualization widget with matching method: "
f"{self.initMatchingMethod.currentText()}"
Expand All @@ -623,6 +639,23 @@ def embeddingViz(self):
"Embedding visualization is not available for fuzzy matching.",
)

def matchingViz(self):
    """Create, populate and display the matching visualization widget.

    The widget is built from the input dataset columns, the unique target
    CDE codes, the matched CDE codes and the currently selected matching
    method. Its word combo box is then filled with the input dataset
    columns, the heatmap figure is generated and the widget is shown.
    """
    # Gather the constructor arguments in the order they are passed
    datasetColumns = self.inputDatasetColumns
    cdeCodes = self.targetCDEs["code"].unique().tolist()
    matchedCodes = self.matchedCdeCodes
    methodName = self.initMatchingMethod.currentText()
    self.matchingWidget = MatchingVisualizationWidget(
        datasetColumns, cdeCodes, matchedCodes, methodName, None
    )
    widget = self.matchingWidget
    widget.set_wordcombobox_items(self.inputDatasetColumns)
    message = (
        "Launch matching visualization widget "
        f"(matching method: {self.initMatchingMethod.currentText()})"
    )
    print(message)
    # Render the heatmap before making the widget visible
    widget.generate_heatmap_figure()
    widget.show()

def addMappingTableRow(self):
"""Add a row to the mapping table."""
# Show a dialog to enter the dataset column name
Expand Down Expand Up @@ -932,6 +965,7 @@ def disableMappingInitItems(self):
self.mappingInitButton.setEnabled(False)
self.initMatchingMethod.setEnabled(False)
self.embeddingVizButton.setEnabled(False)
self.matchingVizButton.setEnabled(False)

def enableMappingInitItems(self):
"""Enable the mapping initialization items."""
Expand Down Expand Up @@ -1132,6 +1166,7 @@ def mappingMatch(self):
self.embeddingVizButton.setEnabled(True)
else:
self.embeddingVizButton.setEnabled(False)
self.matchingVizButton.setEnabled(True)

def selectOutputFilename(self):
"""Select the output filename."""
Expand Down
Loading

0 comments on commit 9024bf3

Please sign in to comment.