diff --git a/src/molearn/data/prepare.py b/src/molearn/data/prepare.py
index 4bb4257..6d45bff 100644
--- a/src/molearn/data/prepare.py
+++ b/src/molearn/data/prepare.py
@@ -309,6 +309,53 @@ def distance_cluster(
         self._find_representatives(idx_idx, cluster_func.labels_)
         self.cluster_method = "CLUSTER_aggl"
 
+    def create_dendrogram(self, distance_threshold=50, output_path="dendrogram.png") -> None:
+        """
+        Cluster the trajectory with hierarchical clustering (Ward linkage) and plot a dendrogram.
+
+        The linkage is computed from ``self.traj_dists``; the figure is saved to ``output_path``
+        and the flat cluster labels are stored in ``self.cluster_labels``.
+
+        :param distance_threshold: cut-off passed to ``fcluster`` (criterion "distance") and used
+            as colour threshold of the dendrogram; frames whose cophenetic distance is at most
+            this value end up in the same cluster (default is 50).
+        :param str output_path: file the dendrogram figure is saved to
+        """
+        assert hasattr(
+            self, "traj_dists"
+        ), "No pairwise frame distances present - read in trajectory first"
+
+        if self.verbose:
+            print("Hierarchical clustering")
+
+        # NOTE(review): scipy's linkage interprets a 2-D array as observation vectors and
+        # only a 1-D array as condensed distances - confirm traj_dists has the intended form.
+        self.linkage_matrix = linkage(self.traj_dists, method="ward")
+
+        # Plot the dendrogram; always close the figure so repeated calls do not leak figures
+        fig = plt.figure(figsize=(10, 7))
+        try:
+            dendrogram(self.linkage_matrix, no_labels=True, color_threshold=distance_threshold)
+            plt.title("Dendrogram of Frames")
+            plt.xlabel("Frame Index")
+            plt.ylabel("Distance")
+            plt.savefig(output_path)
+        finally:
+            plt.close(fig)
+
+        # Cut the tree at distance_threshold to obtain the flat cluster labels
+        self.cluster_labels = fcluster(
+            self.linkage_matrix, t=distance_threshold, criterion="distance"
+        )
+
+        # Store cluster labels and representatives
+        idx_idx = np.arange(len(self.cluster_labels))
+        self._find_representatives(idx_idx, self.cluster_labels)
+        self.cluster_method = "CLUSTER_linkage"
+
+        if self.verbose:
+            print(f"Assigned {len(np.unique(self.cluster_labels))} clusters.")
+
     def pca_cluster(self, n: int = 3) -> None:
         """
         cluster the trajectory with KMeans based on the first 5 principal components
@@ -455,6 +502,78 @@ def create_trajectories(
             self.traj[self.frame_idx[self.test_border :]].save_dcd(
                 os.path.join(self.outpath, f"./{self.traj_name}_test.dcd")
             )
+
+    def create_trajectories_by_dendrogram(self, test_cluster: int) -> None:
+        """
+        Create the test trajectory from one specific cluster and the train
+        trajectory from all the frames excluding that cluster.
+
+        For both sets a dcd trajectory and a text file with the frame indices are
+        written to ``self.outpath``.
+
+        :param int test_cluster: Cluster to use as the test set.
+        :raises ValueError: if no test set is requested (``test_size`` is 0), if no cluster
+            labels are present, if ``test_cluster`` is not an existing cluster label or if
+            no frames would be left for the train set.
+        """
+        if self.test_size == 0.0:
+            raise ValueError("Test set is required to perform this operation.")
+
+        # getattr: the attribute does not exist at all before create_dendrogram was run
+        if getattr(self, "cluster_labels", None) is None:
+            raise ValueError(
+                "Cluster labels are not initialized. Please run create_dendrogram first."
+            )
+
+        assert all(
+            hasattr(self, attr)
+            for attr in (
+                "traj",
+                "traj_name",
+                "traj_dists",
+                "train_idx",
+                "frame_idx",
+                "cluster_idx",
+            )
+        ), "Ensure trajectory is clustered first"
+
+        self.unique_clusters = set(self.cluster_labels)
+
+        # Check if the test cluster is in the unique clusters
+        if test_cluster not in self.unique_clusters:
+            raise ValueError(f"Cluster {test_cluster} is not in the unique clusters.")
+
+        # Separate indices for train and test sets based on the cluster
+        is_test = np.asarray(self.cluster_labels) == test_cluster
+        test_cluster_indices = np.flatnonzero(is_test)
+        train_cluster_indices = np.flatnonzero(~is_test)
+
+        if train_cluster_indices.size == 0:
+            raise ValueError(
+                f"Cluster {test_cluster} contains every frame - no frames left for the train set."
+            )
+
+        if self.verbose:
+            print(f"Creating train and test trajectories with cluster {test_cluster} being the test set.")
+
+        # NOTE(review): the label positions are used to index self.train_idx - this assumes
+        # the labels were computed for exactly the frames in train_idx; confirm.
+        outputs = (
+            (f"{self.traj_name}_train_excluding_cluster_{test_cluster}", train_cluster_indices),
+            (f"{self.traj_name}_test_cluster_{test_cluster}", test_cluster_indices),
+        )
+        for file_stem, cluster_indices in outputs:
+            ori_frame_idx = self.train_idx[cluster_indices]
+            # Save the original frame indices first, then the corresponding frames
+            self._save_idx(
+                os.path.join(self.outpath, f"./{file_stem}_frames.txt"),
+                ori_frame_idx,
+            )
+            self.traj[ori_frame_idx].save_dcd(os.path.join(self.outpath, f"{file_stem}.dcd"))
+
+        if self.verbose:
+            print("Train and test trajectories successfully created.")
+
 
 
 if __name__ == "__main__":