diff --git a/mip_dmp/plot/embedding.py b/mip_dmp/plot/embedding.py index 3ac50d7..70dbf82 100644 --- a/mip_dmp/plot/embedding.py +++ b/mip_dmp/plot/embedding.py @@ -1,3 +1,5 @@ +"""Module to plot the embeddings of the column names and CDE codes.""" + import numpy as np import pandas as pd import matplotlib.pyplot as plt @@ -136,9 +138,14 @@ def pick_event_method(event): def handle_annotations(artist, indices): """Add text annotations to closest point of the cursor when the mouse button was pressed.""" # Get the type of the artist that can be "cde" or "column" + artist_type = None for k in artists.keys(): if artist == artists[k]: artist_type = k + # If the artist type is not defined yet, return + if artist_type is None: + return + # Get the dataframe of the artist type ("cde" or "column") artist_df = df[df["type"] == artist_type] # For each index of the artist for ind in indices: diff --git a/mip_dmp/plot/matching.py b/mip_dmp/plot/matching.py new file mode 100644 index 0000000..ae1e3e5 --- /dev/null +++ b/mip_dmp/plot/matching.py @@ -0,0 +1,83 @@ +"""Module to plot the initial matching results between the input dataset columns and the target CDE codes.""" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +# Define colors used to plot the column and CDE code embeddings +# '#4fa08b' green / '#009E73' green / '#0072B2' blue / '#FFA500' orange +COLORS = ["#0072B2", "#FFA500"] +# Set seaborn style +sns.set_style("darkgrid") +sns.set( + rc={ + "axes.facecolor": "#081512", + "figure.facecolor": "#081512", + "text.color": "white", + "axes.edgecolor": "white", + "patch.edgecolor": "#081512", + "xtick.color": "white", + "ytick.color": "white", + "axes.labelcolor": "white", + "grid.color": "#4fa08b", + "axes3d.xaxis.panecolor": "#081512", + "axes3d.yaxis.panecolor": "#081512", + "axes3d.zaxis.panecolor": "#081512", + "ytick.major.pad": 8, + } +) + + +def heatmap_matching( + figure, matrix, inputDatasetColumns, targetCDECodes, matchingMethod +): + """Render a heatmap of the initial matching results between the input dataset columns and the target CDE codes. + + Parameters + ---------- + figure: matplotlib.figure.Figure + Figure to render the heatmap of the matching results. + + matrix: numpy.ndarray + Similarity / distance matrix of the matching results. + + inputDatasetColumns: list + List of the input dataset columns. Used as ytick labels. + + targetCDECodes: list + List of the target CDE codes. Used as xtick labels. + + matchingMethod: str + Matching method used to generate the similarity / distance matrix. + Used to generate the title of the figure. + """ + # Generate the figure + left, bottom, width, height = 0.2, 0.1, 0.8, 0.2 + ax = figure.add_axes([left, bottom, width, height]) + xtickLabels = targetCDECodes + ytickLabels = inputDatasetColumns + sns.heatmap( + matrix, + ax=ax, + xticklabels=xtickLabels, + yticklabels=ytickLabels, + annot=True, + fmt=".2f", + cmap="viridis", + ) + ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False) + + distance_type = ( + r"$(1 - 0.01 * \mathrm{LevenshteinRatio})$" + if matchingMethod == "fuzzy" + else "Cosine distance" + ) + title = ( + f"Distances for the most 10 similar CDE codes\n" + f"(method: {matchingMethod}, distance: {distance_type})" + ) + ax.set_title(title) + plt.xticks(rotation=75) + plt.yticks(rotation=90) + return figure diff --git a/mip_dmp/process/matching.py b/mip_dmp/process/matching.py index 189098d..c5eec8d 100644 --- a/mip_dmp/process/matching.py +++ b/mip_dmp/process/matching.py @@ -1,6 +1,8 @@ """Module that provides functions to support the matching of dataset columns to CDEs.""" # External imports +import ast +import numpy as np import pandas as pd from fuzzywuzzy import fuzz @@ -95,13 +97,13 @@ def match_columns_to_cdes( ] # Select the nb_kept_matches first matched CDE codes. ) ) - # Store the first nb_fuzy_matches matched CDE codes in the dictionary. + # Store the first nb_fuzzy_matches matched CDE codes in the dictionary. for i, dataset_column in enumerate(dataset.columns): + words = ast.literal_eval(matches.to_list()[i]) matched_cde_codes[dataset_column] = { - "words": matches[i][:nb_kept_matches], + "words": words, "distances": [ - fuzz.ratio(dataset_column, match) - for match in matches[i][:nb_kept_matches] + (1 - 0.01 * fuzz.ratio(dataset_column, match)) for match in words ], "embeddings": [None] * nb_kept_matches, } @@ -296,3 +298,43 @@ def generate_initial_transform(dataset_column_values, cde_code_values, dataset_c for dataset_column_value in dataset_column_values } ) + + +def make_distance_vector(matchedCdeCodes, inputDatasetColumn): + """Make the n closest match distance vector. + + Parameters + ---------- + matchedCdeCodes : dict + Dictionary of the matching results in the form:: + + { + "inputDatasetColumn1": { + "words": ["word1", "word2", ...], + "distances": [distance1, distance2, ...], + "embeddings": [embedding1, embedding2, ...] + }, + "inputDatasetColumn2": { + "words": ["word1", "word2", ...], + "distances": [distance1, distance2, ...], + "embeddings": [embedding1, embedding2, ...] + }, + ... + } + + inputDatasetColumn : lstr + Input dataset column name. + + Returns + ------- + distanceVector : numpy.ndarray + Similarity/distance vector. + """ + # Get the matched CDE codes for the current input dataset column + matches = matchedCdeCodes[inputDatasetColumn] + # Initialize the similarity matrix + similarityVector = np.zeros((1, len(matches["words"]))) + # Update the similarity matrix + similarityVector[0, :] = matches["distances"] + # Return the similarity matrix + return similarityVector diff --git a/mip_dmp/qt5/assets/heatmap_matching.png b/mip_dmp/qt5/assets/heatmap_matching.png new file mode 100644 index 0000000..d08cd8a Binary files /dev/null and b/mip_dmp/qt5/assets/heatmap_matching.png differ diff --git a/mip_dmp/qt5/components/dataset_mapper_window.py b/mip_dmp/qt5/components/dataset_mapper_window.py index fd58704..08c03ab 100644 --- a/mip_dmp/qt5/components/dataset_mapper_window.py +++ b/mip_dmp/qt5/components/dataset_mapper_window.py @@ -44,12 +44,15 @@ # NoEditorDelegate, PandasTableModel, ) +from mip_dmp.qt5.components.embedding_visualization_widget import ( + WordEmbeddingVisualizationWidget, +) from mip_dmp.qt5.components.matching_visualization_widget import ( - MappingMatchVisualizationWidget, + MatchingVisualizationWidget, ) # Constants -WINDOW_NAME = "MIPDatasetMapperUI" +WINDOW_NAME = "MIP Dataset Mapper" class MIPDatasetMapperWindow(object): @@ -135,6 +138,8 @@ class MIPDatasetMapperWindow(object): "embeddingCanvas", "inputDatasetColumnEmbeddings", "targetCDEsEmbeddings", + "matchingVizButton", + "matchingWidget", ] def __init__(self, mainWindow): @@ -272,13 +277,23 @@ def createToolBar(self, mainWindow): mainWindow, ) self.toolBar.addAction(self.mappingInitButton) + self.matchingVizButton = QAction( + QIcon( + pkg_resources.resource_filename( + "mip_dmp", os.path.join("qt5", "assets", "heatmap_matching.png") + ) + ), + "Visualize Column /CDE Match Distances", + mainWindow, + ) + self.toolBar.addAction(self.matchingVizButton) self.embeddingVizButton = QAction( QIcon( pkg_resources.resource_filename( "mip_dmp", os.path.join("qt5", "assets", "plot_embedding.png") ) ), - "Plot Word Embedding in 3D with t-SNE", + "Visualize Word Embedding Matches in 3D (Enabled only for GloVe and Chars2Vec methods)", mainWindow, ) self.toolBar.addAction(self.embeddingVizButton) @@ -596,10 +611,11 @@ def connectButtons(self): self.mappingTableViewAddRowButton.clicked.connect(self.addMappingTableRow) self.mappingTableViewDeleteRowButton.clicked.connect(self.deleteMappingTableRow) self.embeddingVizButton.triggered.connect(self.embeddingViz) + self.matchingVizButton.triggered.connect(self.matchingViz) def embeddingViz(self): """Open the embedding visualization window.""" - self.embeddingWidget = MappingMatchVisualizationWidget() + self.embeddingWidget = WordEmbeddingVisualizationWidget() print( "Launch visualization widget with matching method: " f"{self.initMatchingMethod.currentText()}" @@ -623,6 +639,23 @@ def embeddingViz(self): "Embedding visualization is not available for fuzzy matching.", ) + def matchingViz(self): + """Open the matching visualization window.""" + self.matchingWidget = MatchingVisualizationWidget( + self.inputDatasetColumns, + self.targetCDEs["code"].unique().tolist(), + self.matchedCdeCodes, + self.initMatchingMethod.currentText(), + None, + ) + self.matchingWidget.set_wordcombobox_items(self.inputDatasetColumns) + print( + "Launch matching visualization widget " + f"(matching method: {self.initMatchingMethod.currentText()})" + ) + self.matchingWidget.generate_heatmap_figure() + self.matchingWidget.show() + def addMappingTableRow(self): """Add a row to the mapping table.""" # Show a dialog to enter the dataset column name @@ -932,6 +965,7 @@ def disableMappingInitItems(self): self.mappingInitButton.setEnabled(False) self.initMatchingMethod.setEnabled(False) self.embeddingVizButton.setEnabled(False) + self.matchingVizButton.setEnabled(False) def enableMappingInitItems(self): """Enable the mapping initialization items.""" @@ -1132,6 +1166,7 @@ def mappingMatch(self): self.embeddingVizButton.setEnabled(True) else: self.embeddingVizButton.setEnabled(False) + self.matchingVizButton.setEnabled(True) def selectOutputFilename(self): """Select the output filename.""" diff --git a/mip_dmp/qt5/components/embedding_visualization_widget.py b/mip_dmp/qt5/components/embedding_visualization_widget.py new file mode 100644 index 0000000..12f2e4b --- /dev/null +++ b/mip_dmp/qt5/components/embedding_visualization_widget.py @@ -0,0 +1,228 @@ +"""Class for the widget that supports the visualization of the initial automated mapping matches via embedding.""" + +# External imports +import os +import numpy as np +import matplotlib.pyplot as plt +import pkg_resources +from matplotlib.backends.backend_qt5agg import ( + FigureCanvasQTAgg as FigureCanvas, + NavigationToolbar2QT as NavigationToolbar, +) +from PySide2.QtCore import QCoreApplication +from PySide2.QtWidgets import QVBoxLayout, QWidget, QComboBox + +# Internal imports +from mip_dmp.plot.embedding import scatterplot_embeddings +from mip_dmp.process.embedding import generate_embeddings, reduce_embeddings_dimension + + +WINDOW_NAME = "Word Embedding Matches Visualization" + + +class WordEmbeddingVisualizationWidget(QWidget): + """Class for the widget that supports the visualization of the automated column / CDE code matches via embedding.""" + + def __init__(self, parent=None): + """Initialize the widget. If parent is `None`, the widget renders as a separate window.""" + super(WordEmbeddingVisualizationWidget, self).__init__(parent) + self.adjustWindow() + self.widgetLayout = QVBoxLayout() + self.setLayout(self.widgetLayout) + # Set up the combo box for selecting the dimensionality reduction method + self.dimReductionMethodComboBox = QComboBox() + self.dimReductionMethodComboBox.addItems(["tsne", "pca"]) + self.widgetLayout.addWidget(self.dimReductionMethodComboBox) + # Set up the combo box for selecting the word to visualize + # its dimensionaly reduced embedding vector in the 3D scatter plot + # with the ones of the CDE codes + self.wordComboBox = QComboBox() + self.widgetLayout.addWidget(self.wordComboBox) + # Set up the matplotlib figure and canvas + self.canvasLayout = QVBoxLayout() + self.figure = plt.figure(figsize=(6, 6)) + self.canvas = FigureCanvas(self.figure) + self.toolbar = NavigationToolbar(self.canvas, self) + self.canvasLayout.addWidget(self.canvas) + self.canvasLayout.addWidget(self.toolbar) + self.widgetLayout.addLayout(self.canvasLayout, stretch=1) + # Initialize the class attributes + self.inputDatasetColumns = list() + self.targetCDECodes = list() + self.inputDatasetColumnEmbeddings = list() + self.targetCDECodeEmbeddings = list() + self.matchedCdeCodes = dict() + self.matchingMethod = None + self.embeddings = dict() + # Connect signals to slots + self.dimReductionMethodComboBox.currentIndexChanged.connect( + self.generate_embedding_figure + ) + self.wordComboBox.currentIndexChanged.connect(self.generate_embedding_figure) + + def adjustWindow(self): + """Adjust the window size, Qt Style Sheet, and title. + + Parameters + ---------- + mainWindow : QMainWindow + The main window of the application. + """ + # Adjust the window size + # self.resize(1280, 720) + # Set the window Qt Style Sheet + styleSheetFile = pkg_resources.resource_filename( + "mip_dmp", os.path.join("qt5", "assets", "stylesheet.qss") + ) + with open(styleSheetFile, "r") as fh: + self.setStyleSheet(fh.read()) + # Set the window title + self.setWindowTitle( + QCoreApplication.translate(f"{WINDOW_NAME}", f"{WINDOW_NAME}", None) + ) + + def set_word_list(self, wordList): + """Set the list of words that can be visualized in the 3D scatter plot. + + wordList: list + List of words to visualize in the 3D scatter plot + """ + self.wordComboBox.clear() + self.wordComboBox.addItems(wordList) + + def set_matching_method(self, matchingMethod): + """Set the matching method. + + matchingMethod: str + Matching method. Can be "glove" or "chars2vec" + """ + self.matchingMethod = matchingMethod + + def generate_embeddings( + self, inputDatasetColumns: list, targetCDECodes: list, matchingMethod: str + ): + """Generate the embeddings of the columns and CDE codes. + + Set the input dataset columns (`self.inputDatasetColumns`), the target CDE codes (`self.targetCDECodes`), + the input dataset column embeddings (`self.inputDatasetColumnEmbeddings`) and the target CDE code embeddings + (`self.targetCDECodeEmbeddings`). + + The embeddings are generated using the specified matching method (`matchingMethod`). + The matching method can be "glove" or "chars2vec". + + inputDatasetColumns: list + List of the input dataset columns. + + targetCDECodes: list + List of the target CDE codes. + + matchingMethod: str + Matching method. Can be "glove" or "chars2vec" + """ + self.set_matching_method(matchingMethod) + self.inputDatasetColumns = inputDatasetColumns + self.targetCDECodes = targetCDECodes + self.inputDatasetColumnEmbeddings = generate_embeddings( + inputDatasetColumns, matchingMethod + ) + self.targetCDECodeEmbeddings = generate_embeddings( + targetCDECodes, matchingMethod + ) + + def set_embeddings( + self, + inputDatasetColumnEmbeddings: list, + inputDatasetColumns: list, + targetCDECodeEmbeddings: list, + targetCDCCodes: list, + matchedCdeCodes: dict, + matchingMethod: str, + ): + """Set the input dataset column and target CDE code embeddings. + + inputDatasetColumnEmbeddings: list + List of the input dataset column embeddings. + + inputDatasetColumns: list + List of the input dataset columns. + + targetCDECodeEmbeddings: list + List of the target CDE code embeddings. + + targetCDCCodes: list + List of the target CDE codes. + + matchedCdeCodes: dict + Dictionary of the matched CDE codes in the form:: + + { + "input_dataset_column1": { + "words": ["cde_code1", "cde_code2", ...], + "embeddings": [embedding_vector1, embedding_vector2, ...] + "distances": [distance1, distance2, ...] + }, + "input_dataset_column2": { + "words": ["cde_code1", "cde_code2", ...], + "embeddings": [embedding_vector1, embedding_vector2, ...] + "distances": [distance1, distance2, ...] + }, + ... + } + + matchingMethod: str + Matching method. Can be "glove" or "chars2vec". + """ + self.set_matching_method(matchingMethod) + self.inputDatasetColumnEmbeddings = inputDatasetColumnEmbeddings + self.inputDatasetColumns = inputDatasetColumns + self.targetCDECodeEmbeddings = targetCDECodeEmbeddings + self.targetCDECodes = targetCDCCodes + self.matchedCdeCodes = matchedCdeCodes + # Reduce embeddings dimension to 3 components via t-SNE or PCA for visualization + dim_reduction_method = self.dimReductionMethodComboBox.currentText() + x, y, z = reduce_embeddings_dimension( + self.inputDatasetColumnEmbeddings + self.targetCDECodeEmbeddings, + reduce_method=dim_reduction_method, + ) + # Set the dictionary with the embeddings and their labels, format expected + # by the scatterplot function + self.embeddings = dict( + { + "x": x, + "y": y, + "z": z, + "label": self.inputDatasetColumns + self.targetCDECodes, + "type": ( + ["column"] * len(self.inputDatasetColumns) + + ["cde"] * len(self.targetCDECodes) + ), + } + ) + + def set_wordcombobox_items(self, wordList): + """Set the items of the word combo box. + + wordList: list + List of words to visualize in the combo box of the widget + that controls the selection of the word to visualize in the + 3D scatter plot. + """ + self.wordComboBox.clear() + self.wordComboBox.addItems(wordList) + + def generate_embedding_figure(self): + """Generate 3D scatter plot showing dimensionality-reduced embedding vectors of the words.""" + + if ( + len(self.inputDatasetColumnEmbeddings) > 0 + and len(self.targetCDECodeEmbeddings) > 0 + ): + # Generate 3D scatter plot + scatterplot_embeddings( + self.figure, + self.embeddings, + self.matchedCdeCodes, + self.wordComboBox.currentText(), + ) + # Draw the figure + self.figure.canvas.draw() diff --git a/mip_dmp/qt5/components/matching_visualization_widget.py b/mip_dmp/qt5/components/matching_visualization_widget.py index 9648067..6f0b214 100644 --- a/mip_dmp/qt5/components/matching_visualization_widget.py +++ b/mip_dmp/qt5/components/matching_visualization_widget.py @@ -1,200 +1,144 @@ -"""Class for the widget that supports the visualization of the initial automated mapping matches via embedding.""" +"""Class for the widget that supports the visualization of the distances obtained by the automated mapping matches for the n most similar CDE codes.""" # External imports -import numpy as np +import os import matplotlib.pyplot as plt +import pkg_resources from matplotlib.backends.backend_qt5agg import ( FigureCanvasQTAgg as FigureCanvas, NavigationToolbar2QT as NavigationToolbar, ) + +from PySide2.QtCore import QCoreApplication from PySide2.QtWidgets import QVBoxLayout, QWidget, QComboBox # Internal imports -from mip_dmp.plot.embedding import scatterplot_embeddings -from mip_dmp.process.embedding import generate_embeddings, reduce_embeddings_dimension +from mip_dmp.plot.matching import heatmap_matching +from mip_dmp.process.matching import make_distance_vector -class MappingMatchVisualizationWidget(QWidget): - """Class for the widget that supports the visualization of the automated column / CDE code matches.""" +WINDOW_NAME = "Column /CDE Match Distance Visualization" - def __init__(self, parent=None): - """Initialize the widget. If parent is `None`, the widget renders as a separate window.""" - super(MappingMatchVisualizationWidget, self).__init__(parent) - self.widgetLayout = QVBoxLayout() - self.setLayout(self.widgetLayout) - # Set up the combo box for selecting the dimensionality reduction method - self.dimReductionMethodComboBox = QComboBox() - self.dimReductionMethodComboBox.addItems(["tsne", "pca"]) - self.widgetLayout.addWidget(self.dimReductionMethodComboBox) - # Set up the combo box for selecting the word to visualize - # its dimensionaly reduced embedding vector in the 3D scatter plot - # with the ones of the CDE codes - self.wordComboBox = QComboBox() - self.widgetLayout.addWidget(self.wordComboBox) - # Set up the matplotlib figure and canvas - self.canvasLayout = QVBoxLayout() - self.figure = plt.figure(figsize=(6, 6)) - self.canvas = FigureCanvas(self.figure) - self.toolbar = NavigationToolbar(self.canvas, self) - self.canvasLayout.addWidget(self.canvas) - self.canvasLayout.addWidget(self.toolbar) - self.widgetLayout.addLayout(self.canvasLayout, stretch=1) - # Initialize the class attributes - self.inputDatasetColumns = list() - self.targetCDECodes = list() - self.inputDatasetColumnEmbeddings = list() - self.targetCDECodeEmbeddings = list() - self.matchedCdeCodes = dict() - self.matchingMethod = None - self.embeddings = dict() - # Connect signals to slots - self.dimReductionMethodComboBox.currentIndexChanged.connect( - self.generate_embedding_figure - ) - self.wordComboBox.currentIndexChanged.connect(self.generate_embedding_figure) - def set_word_list(self, wordList): - """Set the list of words that can be visualized in the 3D scatter plot. +class MatchingVisualizationWidget(QWidget): + """Class for the widget that supports the visualization of the distances / similarity measures obtained by the automated mapping matches for the n most similar CDE codes.""" - wordList: list - List of words to visualize in the 3D scatter plot - """ - self.wordComboBox.clear() - self.wordComboBox.addItems(wordList) - - def set_matching_method(self, matchingMethod): - """Set the matching method. - - matchingMethod: str - Matching method. Can be "glove" or "chars2vec" - """ - self.matchingMethod = matchingMethod - - def generate_embeddings( - self, inputDatasetColumns: list, targetCDECodes: list, matchingMethod: str - ): - """Generate the embeddings of the columns and CDE codes. - - Set the input dataset columns (`self.inputDatasetColumns`), the target CDE codes (`self.targetCDECodes`), - the input dataset column embeddings (`self.inputDatasetColumnEmbeddings`) and the target CDE code embeddings - (`self.targetCDECodeEmbeddings`). - - The embeddings are generated using the specified matching method (`matchingMethod`). - The matching method can be "glove" or "chars2vec". - - inputDatasetColumns: list - List of the input dataset columns. - - targetCDECodes: list - List of the target CDE codes. - - matchingMethod: str - Matching method. Can be "glove" or "chars2vec" - """ - self.set_matching_method(matchingMethod) - self.inputDatasetColumns = inputDatasetColumns - self.targetCDECodes = targetCDECodes - self.inputDatasetColumnEmbeddings = generate_embeddings( - inputDatasetColumns, matchingMethod - ) - self.targetCDECodeEmbeddings = generate_embeddings( - targetCDECodes, matchingMethod - ) - - def set_embeddings( + def __init__( self, - inputDatasetColumnEmbeddings: list, - inputDatasetColumns: list, - targetCDECodeEmbeddings: list, - targetCDCCodes: list, - matchedCdeCodes: dict, - matchingMethod: str, + inputDatasetColumns=None, + targetCDECodes=None, + matchedCdeCodes=None, + matchingMethod=None, + parent=None, ): - """Set the input dataset column and target CDE code embeddings. - - inputDatasetColumnEmbeddings: list - List of the input dataset column embeddings. + """Initialize the widget. If parent is `None`, the widget renders as a separate window. inputDatasetColumns: list List of the input dataset columns. - targetCDECodeEmbeddings: list - List of the target CDE code embeddings. - - targetCDCCodes: list + targetCDECodes: list List of the target CDE codes. matchedCdeCodes: dict - Dictionary of the matched CDE codes in the form:: + Dictionary with the matched CDE codes in the following format:: { - "input_dataset_column1": { - "words": ["cde_code1", "cde_code2", ...], - "embeddings": [embedding_vector1, embedding_vector2, ...] - "distances": [distance1, distance2, ...] + "input_dataset_column_1": { + "words": [ "cde_code_1", "cde_code_2", ... ], + "distances": [ distance_1, distance_2, ... ], + "embeddings": [ embedding_1, embedding_2, ... ] }, - "input_dataset_column2": { - "words": ["cde_code1", "cde_code2", ...], - "embeddings": [embedding_vector1, embedding_vector2, ...] - "distances": [distance1, distance2, ...] + "input_dataset_column_2": { + "words": [ "cde_code_1", "cde_code_2", ... ], + "distances": [ distance_1, distance_2, ... ], + "embeddings": [ embedding_1, embedding_2, ... ] }, ... } matchingMethod: str - Matching method. Can be "glove" or "chars2vec". + String with the matching method. Can be one of the following: + - `fuzzy` + - `chars2vec` + - `glove` + """ + super(MatchingVisualizationWidget, self).__init__(parent) + self.adjustWindow() + self.widgetLayout = QVBoxLayout() + self.setLayout(self.widgetLayout) + # Set up the combo box for selecting the word to visualize + # its dimensionaly reduced embedding vector in the 3D scatter plot + # with the ones of the CDE codes + self.wordComboBox = QComboBox() + self.widgetLayout.addWidget(self.wordComboBox) + # Set up the matplotlib figure and canvas + self.canvasLayout = QVBoxLayout() + self.figure = plt.figure(figsize=(12, 12)) + self.canvas = FigureCanvas(self.figure) + self.toolbar = NavigationToolbar(self.canvas, self) + self.canvasLayout.addWidget(self.canvas) + self.canvasLayout.addWidget(self.toolbar) + self.widgetLayout.addLayout(self.canvasLayout, stretch=1) + # Initialize the class attributes (if set) + self.inputDatasetColumns = ( + inputDatasetColumns if inputDatasetColumns else list() + ) + self.targetCDECodes = targetCDECodes if targetCDECodes else list() + self.matchedCdeCodes = matchedCdeCodes if matchedCdeCodes else dict() + self.matchingMethod = matchingMethod if matchingMethod else None + # Connect the combo box to the function that generates the heatmap + self.wordComboBox.currentIndexChanged.connect(self.generate_heatmap_figure) + + def adjustWindow(self): + """Adjust the window size, Qt Style Sheet, and title. + + Parameters + ---------- + mainWindow : QMainWindow + The main window of the application. """ - self.set_matching_method(matchingMethod) - self.inputDatasetColumnEmbeddings = inputDatasetColumnEmbeddings - self.inputDatasetColumns = inputDatasetColumns - self.targetCDECodeEmbeddings = targetCDECodeEmbeddings - self.targetCDECodes = targetCDCCodes - self.matchedCdeCodes = matchedCdeCodes - # Reduce embeddings dimension to 3 components via t-SNE or PCA for visualization - dim_reduction_method = self.dimReductionMethodComboBox.currentText() - x, y, z = reduce_embeddings_dimension( - self.inputDatasetColumnEmbeddings + self.targetCDECodeEmbeddings, - reduce_method=dim_reduction_method, + # Adjust the window size + # self.resize(1280, 720) + # Set the window Qt Style Sheet + styleSheetFile = pkg_resources.resource_filename( + "mip_dmp", os.path.join("qt5", "assets", "stylesheet.qss") ) - # Set the dictionary with the embeddings and their labels, format expected - # by the scatterplot function - self.embeddings = dict( - { - "x": x, - "y": y, - "z": z, - "label": self.inputDatasetColumns + self.targetCDECodes, - "type": ( - ["column"] * len(self.inputDatasetColumns) - + ["cde"] * len(self.targetCDECodes) - ), - } + with open(styleSheetFile, "r") as fh: + self.setStyleSheet(fh.read()) + # Set the window title + self.setWindowTitle( + QCoreApplication.translate(f"{WINDOW_NAME}", f"{WINDOW_NAME}", None) ) def set_wordcombobox_items(self, wordList): """Set the items of the word combo box. + Parameters + ---------- wordList: list - List of words to visualize in the combo box of the widget - that controls the selection of the word to visualize in the - 3D scatter plot. + List of the words to add to the combo box. """ self.wordComboBox.clear() self.wordComboBox.addItems(wordList) - def generate_embedding_figure(self): - """Generate 3D scatter plot showing dimensionality-reduced embedding vectors of the words.""" - - if ( - len(self.inputDatasetColumnEmbeddings) > 0 - and len(self.targetCDECodeEmbeddings) > 0 - ): - # Generate 3D scatter plot - scatterplot_embeddings( - self.figure, - self.embeddings, - self.matchedCdeCodes, - self.wordComboBox.currentText(), - ) - # Draw the figure - self.figure.canvas.draw() + def generate_heatmap_figure(self): + """Generate a heatmap figure with seaborn that shows the similarity / distance matrix of the input dataset columns and the target CDE codes.""" + # Generate the distance vector + distanceVector = make_distance_vector( + self.matchedCdeCodes, self.wordComboBox.currentText() + ) + # Generate the heatmap + self.figure.clear() + self.figure = heatmap_matching( + self.figure, + distanceVector, + [ + self.wordComboBox.currentText() + ], # give the input dataset column only for y labels + self.matchedCdeCodes[self.wordComboBox.currentText()][ + "words" + ], # give the n most similar CDE codes for x labels + self.matchingMethod, + ) + # Draw the figure + self.figure.canvas.draw() diff --git a/mip_dmp/ui/mip_dataset_mapper_ui.py b/mip_dmp/ui/mip_dataset_mapper_ui.py index b7bcfbd..b95e242 100644 --- a/mip_dmp/ui/mip_dataset_mapper_ui.py +++ b/mip_dmp/ui/mip_dataset_mapper_ui.py @@ -39,7 +39,10 @@ def setIcon(self): def closeEvent(self, event): """Close all windows.""" - self.ui.embeddingWidget.close() + if hasattr(self.ui, "embeddingWidget"): + self.ui.embeddingWidget.close() + if hasattr(self.ui, "matchingWidget"): + self.ui.matchingWidget.close() self.close()