Merge pull request #24 from Degiacomi-Lab/path_gen_example
added example to generate structures along path
Showing 1 changed file with 139 additions and 0 deletions.
import os
import sys

import numpy as np
import torch
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.join(os.path.abspath(os.pardir), "src"))
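# the path insertion above lets the example import molearn directly from a
# source checkout (../src) without installing the package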
from molearn.models.foldingnet import AutoEncoder
from molearn.analysis import MolearnAnalysis, get_path, oversample
from molearn.data import PDBData

def main():
    # Note: running the code below within a function is necessary to ensure that
    # multiprocessing (used to calculate DOPE and Ramachandran scores) runs correctly

    print("> Loading network parameters...")
    fname = f"xbb_foldingnet_checkpoints{os.sep}checkpoint_epoch208_loss-4.205589803059896.ckpt"
    # use the GPU if available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(fname, map_location=device)
    net = AutoEncoder(**checkpoint["network_kwargs"])
    net.load_state_dict(checkpoint["model_state_dict"])
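    # note: the checkpoint stores the network's constructor keyword arguments
    # alongside the weights, so the architecture is rebuilt without hard-coding it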

    # the network is currently on the CPU; if a GPU is available, move it there
    if torch.cuda.is_available():
        net.to(device)

    print("> Loading training data...")

    MA = MolearnAnalysis()
    MA.set_network(net)

    # increasing the batch size makes encoding/decoding operations faster,
    # but more memory demanding
    MA.batch_size = 4

    # increasing processes makes DOPE and Ramachandran score calculations faster,
    # but more memory demanding
    MA.processes = 2

    # what follows is a method to re-create the training and test set:
    # by setting the manual seed and loading the dataset in the same order as when
    # the neural network was trained, the same train-test split is obtained
    data = PDBData()
    data.import_pdb(
        "./clustered/MurDopen_CLUSTER_aggl_train.dcd",
        "./clustered/MurDopen_NEW_TOPO.pdb",
    )
    data.fix_terminal()
    data.atomselect(atoms=["CA", "C", "N", "CB", "O"])
    data.prepare_dataset()
    data_train, data_test = data.split(manual_seed=25)

    # store the training and test set in the MolearnAnalysis instance;
    # the second parameter of the following commands can be either a PDBData
    # instance or a path to a multi-PDB file
    MA.set_dataset("training", data_train)
    MA.set_dataset("test", data_test)

    print("> Generating error landscape...")
    # build a 50x50 grid. By default, it will be 10% larger than the region
    # occupied by all loaded datasets
    grid_side_len = 50
    MA.setup_grid(grid_side_len)
    landscape_err_latent, landscape_err_3d, xaxis, yaxis = MA.scan_error()
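    # scan_error decodes every grid point and measures reconstruction quality:
    # landscape_err_latent and landscape_err_3d are (grid_side_len, grid_side_len)
    # error arrays, while xaxis and yaxis hold the latent coordinates of the grid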

    # OPTIONAL START - for demonstration only: in practice, use your own start and
    # end points for the path
    # sort the grid points by error
    flat_sort = landscape_err_latent.ravel().argsort()
    # flat index of the lowest-error point on the grid
    start = flat_sort[0]
    # number of low-error grid points among which to search for the most distant end point
    n_test = 200
    # end of path
    end = flat_sort[:n_test][np.argmax(np.abs(flat_sort[:n_test] - start))]
    start_idx = np.unravel_index(start, landscape_err_latent.shape)
    end_idx = np.unravel_index(end, landscape_err_latent.shape)
    # OPTIONAL END
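
    # note: np.unravel_index returns (row, column) = (y, x) indices, so below
    # xaxis is indexed with idx[1] and yaxis with idx[0], and the indices are
    # reversed to (x, y) order before being handed to get_path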

    # linear interpolation
    # use your true start and end latent-space coordinates here
    latent_path = oversample(
        np.asarray(
            [
                # start coordinates
                [xaxis[start_idx[1]], yaxis[start_idx[0]]],
                # end coordinates
                [xaxis[end_idx[1]], yaxis[end_idx[0]]],
            ]
        )
    )
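    # oversample interpolates additional points between the given waypoints, so
    # the decoded trajectory moves in small steps rather than one large jump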
    # use A* to find the best path between start and end
    latent_path_astar = get_path(
        np.asarray(start_idx)[::-1],
        np.asarray(end_idx)[::-1],
        landscape_err_latent,
        xaxis,
        yaxis,
    )[0]
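    # get_path searches the error landscape with A*, penalising passage through
    # high-error grid points, so the path favours well-reconstructed regions;
    # [0] selects the path coordinates from the returned tuple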

    # OPTIONAL START - plotting of the landscape with start, end, and both paths
    fig, ax = plt.subplots()
    sm = ax.pcolormesh(xaxis, yaxis, landscape_err_latent)
    cbar = fig.colorbar(sm, orientation="vertical")
    cbar.ax.set_ylabel("RMSD in Å", rotation=270)
    cbar.ax.get_yaxis().labelpad = 15
    ax.scatter(
        latent_path[:, 0], latent_path[:, 1], label="direct", s=3, color="firebrick"
    )
    ax.scatter(
        latent_path_astar[:, 0],
        latent_path_astar[:, 1],
        label="astar",
        s=3,
        color="forestgreen",
    )
    ax.scatter(
        xaxis[start_idx[1]], yaxis[start_idx[0]], label="start", s=10, color="black"
    )
    ax.scatter(xaxis[end_idx[1]], yaxis[end_idx[0]], label="end", s=10, color="yellow")
    plt.legend()
    plt.savefig("Error_grid_sampling.png", dpi=150)
    # OPTIONAL END
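
    # the saved figure shows the error landscape with the direct path in red and
    # the A* path in green; the A* route should visibly detour around high-error
    # regions of the grid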

    # generating new structures
    # !!! relax=True will only work when the network was trained on all atoms !!!
    latent_path = latent_path.reshape(1, -1, 2)
    latent_path_astar = latent_path_astar.reshape(1, -1, 2)
    if not os.path.isdir("newly_generated_structs_linear"):
        os.mkdir("newly_generated_structs_linear")
    if not os.path.isdir("newly_generated_structs_astar"):
        os.mkdir("newly_generated_structs_astar")
    MA.generate(latent_path, "newly_generated_structs_linear", relax=False)
    MA.generate(latent_path_astar, "newly_generated_structs_astar", relax=False)
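    # generate decodes every latent coordinate along each path into a 3D structure
    # and writes the results as PDB files into the directories created above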


if __name__ == "__main__":
    main()