Skip to content

Commit

Permalink
Merge pull request #30 from EngAsal/linkage
Browse files Browse the repository at this point in the history
Add dendrogram plot
  • Loading branch information
degiacom authored Dec 16, 2024
2 parents 6bcdf39 + c5b067a commit 4a95a77
Showing 1 changed file with 110 additions and 0 deletions.
110 changes: 110 additions & 0 deletions src/molearn/data/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,41 @@ def distance_cluster(
self._find_representatives(idx_idx, cluster_func.labels_)
self.cluster_method = "CLUSTER_aggl"

def create_dendrogram(self, distance_threshold=50, output_path="dendrogram.png") -> None:
"""
Cluster the trajectory with hierarchical clustering (linkage) based on the RMSD between the frames
and plot a dendrogram.
Group frames that have pairwise distances less than "distance_threshold" in one cluster (default is 50).
"""
assert hasattr(
self, "traj_dists"
), "No pairwise frame distances present - read in trajectory first"

if self.verbose:
print("Hierarchical clustering")

# Perform hierarchical clustering using scipy
self.linkage_matrix = linkage(self.traj_dists, method="ward")

# Plot the dendrogram
plt.figure(figsize=(10, 7))
dendrogram(self.linkage_matrix, no_labels=True, color_threshold=distance_threshold)
plt.title("Dendrogram of Frames")
plt.xlabel("Frame Index")
plt.ylabel("Distance")
plt.savefig('dendrogram.png')

# Define clusters by specifying n_clusters
self.cluster_labels = fcluster(self.linkage_matrix, t=distance_threshold, criterion="distance")

# Store cluster labels and representatives
idx_idx = np.arange(len(self.cluster_labels))
self._find_representatives(idx_idx, self.cluster_labels)
self.cluster_method = "CLUSTER_linkage"

if self.verbose:
print(f"Assigned {self.n_cluster} clusters.")

def pca_cluster(self, n: int = 3) -> None:
"""
cluster the trajectory with KMeans based on the first 5 principal components
Expand Down Expand Up @@ -455,6 +490,81 @@ def create_trajectories(
self.traj[self.frame_idx[self.test_border :]].save_dcd(
os.path.join(self.outpath, f"./{self.traj_name}_test.dcd")
)

def create_trajectories_by_dendrogram(self, test_cluster: int) -> None:
"""
Create test trajectories based on a specific cluster and create the train
trajectories from all the frames excluding the specific cluster.
:param int test_cluster: Cluster to use as the test set.
"""
if self.test_size == 0.0:
raise ValueError("Test set is required to perform this operation.")

if self.cluster_labels is None:
raise ValueError(
"Cluster labels are not initialized. Please run create_dendrogram first."
)

assert all(
[
hasattr(self, "traj"),
hasattr(self, "traj_name"),
hasattr(self, "traj_dists"),
hasattr(self, "train_idx"),
hasattr(self, "frame_idx"),
hasattr(self, "cluster_idx"), ]
), "Ensure trajectory is clustered first"

# Get cluster labels
# cluster_labels = fcluster(linkage(self.traj_dists, method="ward"), t=self.n_cluster, criterion="maxclust")
self.unique_clusters = set(self.cluster_labels)


# Check if the test cluster is in the unique clusters
if test_cluster not in self.unique_clusters:
raise ValueError(f"Cluster {test_cluster} is not in the unique clusters.")

# Separate indices for train and test sets based on the cluster
test_cluster_indices = np.where(self.cluster_labels == test_cluster)[0]
train_cluster_indices = np.where(self.cluster_labels != test_cluster)[0]

if self.verbose:
print(f"Creating train and test trajectories with cluster {test_cluster} being the test set.")

# Save train trajectory
ori_frame_train_idx = self.train_idx[train_cluster_indices]
self._save_idx(
os.path.join(
self.outpath,
f"./{self.traj_name}_train_excluding_cluster_{test_cluster}_frames.txt",
),
ori_frame_train_idx,
)
self.traj[ori_frame_train_idx].save_dcd(
os.path.join(
self.outpath, f"{self.traj_name}_train_excluding_cluster_{test_cluster}.dcd"
)
)

# Save test trajectory
ori_frame_test_idx = self.train_idx[test_cluster_indices]
self._save_idx(
os.path.join(
self.outpath,
f"./{self.traj_name}_test_cluster_{test_cluster}_frames.txt",
),
ori_frame_test_idx,
)
self.traj[ori_frame_test_idx].save_dcd(
os.path.join(
self.outpath, f"{self.traj_name}_test_cluster_{test_cluster}.dcd"
)
)

if self.verbose:
print("Train and test trajectories successfully created.")



if __name__ == "__main__":
Expand Down

0 comments on commit 4a95a77

Please sign in to comment.