Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New feature] dendrogram clustering #31

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions src/molearn/data/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,41 @@ def distance_cluster(
self._find_representatives(idx_idx, cluster_func.labels_)
self.cluster_method = "CLUSTER_aggl"

def create_dendrogram(self, distance_threshold=50, output_path="dendrogram.png") -> None:
    """
    Cluster the trajectory with hierarchical clustering (Ward linkage) on
    ``self.traj_dists``, save a dendrogram plot and assign every frame to a
    flat cluster.

    Frames whose cophenetic distance is at most ``distance_threshold`` end up
    in the same cluster. The results are stored in ``self.linkage_matrix`` and
    ``self.cluster_labels``; ``self.cluster_method`` is set to
    ``"CLUSTER_linkage"``.

    :param distance_threshold: distance at which the tree is cut into flat
        clusters; also used as the colour threshold of the plot (default 50).
    :param str output_path: file the dendrogram figure is written to
        (default ``"dendrogram.png"``).
    :raises AssertionError: if no pairwise frame distances have been computed.
    """
    assert hasattr(
        self, "traj_dists"
    ), "No pairwise frame distances present - read in trajectory first"

    if self.verbose:
        print("Hierarchical clustering")

    # Perform hierarchical clustering using scipy.
    # NOTE(review): scipy's `linkage` treats a 2-D input as observation
    # vectors, not as a distance matrix. If `traj_dists` is a square pairwise
    # RMSD matrix it may need `squareform` first - confirm against the caller.
    self.linkage_matrix = linkage(self.traj_dists, method="ward")

    # Plot the dendrogram. The figure is closed afterwards so that repeated
    # calls do not accumulate open matplotlib figures.
    fig = plt.figure(figsize=(10, 7))
    try:
        dendrogram(
            self.linkage_matrix, no_labels=True, color_threshold=distance_threshold
        )
        plt.title("Dendrogram of Frames")
        plt.xlabel("Frame Index")
        plt.ylabel("Distance")
        # Honour the caller-supplied destination instead of a fixed file name.
        plt.savefig(output_path)
    finally:
        plt.close(fig)

    # Define flat clusters by cutting the tree at `distance_threshold`
    self.cluster_labels = fcluster(
        self.linkage_matrix, t=distance_threshold, criterion="distance"
    )

    # Store cluster labels and representatives
    idx_idx = np.arange(len(self.cluster_labels))
    self._find_representatives(idx_idx, self.cluster_labels)
    self.cluster_method = "CLUSTER_linkage"

    if self.verbose:
        print(f"Assigned {self.n_cluster} clusters.")

def pca_cluster(self, n: int = 3) -> None:
"""
cluster the trajectory with KMeans based on the first 5 principal components
Expand Down Expand Up @@ -455,6 +490,81 @@ def create_trajectories(
self.traj[self.frame_idx[self.test_border :]].save_dcd(
os.path.join(self.outpath, f"./{self.traj_name}_test.dcd")
)

def create_trajectories_by_dendrogram(self, test_cluster: int) -> None:
    """
    Create the test trajectory from one specific cluster and the train
    trajectory from all frames outside that cluster.

    For both subsets the selected frame indices are written with
    ``self._save_idx`` and the frames are saved as a dcd file in
    ``self.outpath``. ``self.unique_clusters`` is updated to the set of
    available cluster labels.

    :param int test_cluster: Cluster to use as the test set.
    :raises ValueError: if ``self.test_size`` is 0, if no cluster labels are
        available, or if ``test_cluster`` is not one of the cluster labels.
    :raises AssertionError: if the trajectory has not been read and clustered.
    """
    if self.test_size == 0.0:
        raise ValueError("Test set is required to perform this operation.")

    # `cluster_labels` only exists after `create_dendrogram`; use getattr so a
    # missing attribute gives the intended ValueError, not an AttributeError.
    if getattr(self, "cluster_labels", None) is None:
        raise ValueError(
            "Cluster labels are not initialized. Please run create_dendrogram first."
        )

    required_attrs = (
        "traj",
        "traj_name",
        "traj_dists",
        "train_idx",
        "frame_idx",
        "cluster_idx",
    )
    assert all(
        hasattr(self, attr) for attr in required_attrs
    ), "Ensure trajectory is clustered first"

    self.unique_clusters = set(self.cluster_labels)

    # Check if the test cluster is in the unique clusters
    if test_cluster not in self.unique_clusters:
        raise ValueError(f"Cluster {test_cluster} is not in the unique clusters.")

    # Coerce to an array so the element-wise comparison also works when the
    # labels are stored as a plain sequence.
    labels = np.asarray(self.cluster_labels)
    in_test_cluster = labels == test_cluster
    test_cluster_indices = np.where(in_test_cluster)[0]
    train_cluster_indices = np.where(~in_test_cluster)[0]

    if self.verbose:
        print(f"Creating train and test trajectories with cluster {test_cluster} being the test set.")

    # NOTE(review): the positions in `cluster_labels` are used to index
    # `self.train_idx`; this presumes both refer to the same frames in the
    # same order - confirm against where `traj_dists` is computed.
    subsets = (
        (
            f"{self.traj_name}_train_excluding_cluster_{test_cluster}",
            train_cluster_indices,
        ),
        (f"{self.traj_name}_test_cluster_{test_cluster}", test_cluster_indices),
    )
    # Save the train subset first, then the test subset: index file + dcd each
    for file_stem, member_positions in subsets:
        ori_frame_idx = self.train_idx[member_positions]
        self._save_idx(
            os.path.join(self.outpath, f"./{file_stem}_frames.txt"),
            ori_frame_idx,
        )
        self.traj[ori_frame_idx].save_dcd(
            os.path.join(self.outpath, f"{file_stem}.dcd")
        )

    if self.verbose:
        print("Train and test trajectories successfully created.")



if __name__ == "__main__":
Expand Down
Loading