-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Molecule graph environment #139
base: main
Are you sure you want to change the base?
Conversation
…ni_proxy # Conflicts: # gflownet/envs/alaninedipeptide_mixture.py # gflownet/utils/molecule/datasets.py # gflownet/utils/molecule/dgl_conformer.py # gflownet/utils/molecule/old_conformer.py # gflownet/utils/molecule/rdkit_conformer.py # gflownet/utils/molecule/torsions.py # tests/gflownet/utils/molecule/test_torsions.py
Merge main into molecule_graph_env
Using TorchANI proxy for AlanineDipeptide
# Taken from https://pyxtal.readthedocs.io/en/latest/_modules/pyxtal/molecule.html. | ||
|
||
from operator import itemgetter | ||
|
||
|
||
def find_rotor_from_smile(smile): | ||
""" | ||
Find the positions of rotatable bonds in the molecule. | ||
""" | ||
|
||
def cleaner(list_to_clean, neighbors): | ||
""" | ||
Remove duplicate torsion from a list of atom index tuples. | ||
""" | ||
|
||
for_remove = [] | ||
for x in reversed(range(len(list_to_clean))): | ||
ix0 = itemgetter(0)(list_to_clean[x]) | ||
ix3 = itemgetter(3)(list_to_clean[x]) | ||
# for i-j-k-l, we don't want i, l are the ending members | ||
# here C-C-S=O is not a good choice since O is only 1-coordinated | ||
if neighbors[ix0] > 1 and neighbors[ix3] > 1: | ||
for y in reversed(range(x)): | ||
ix1 = itemgetter(1)(list_to_clean[x]) | ||
ix2 = itemgetter(2)(list_to_clean[x]) | ||
iy1 = itemgetter(1)(list_to_clean[y]) | ||
iy2 = itemgetter(2)(list_to_clean[y]) | ||
if [ix1, ix2] == [iy1, iy2] or [ix1, ix2] == [iy2, iy1]: | ||
for_remove.append(y) | ||
else: | ||
for_remove.append(x) | ||
clean_list = [] | ||
for i, v in enumerate(list_to_clean): | ||
if i not in set(for_remove): | ||
clean_list.append(v) | ||
return clean_list | ||
|
||
if smile in ["Cl-", "F-", "Br-", "I-", "Li+", "Na+"]: | ||
return [] | ||
else: | ||
from rdkit import Chem | ||
|
||
smarts_torsion1 = "[*]~[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]~[*]" | ||
smarts_torsion2 = "[*]~[^2]=[^2]~[*]" # C=C bonds | ||
# smarts_torsion2="[*]~[^1]#[^1]~[*]" # C-C triples bonds, to be fixed | ||
|
||
mol = Chem.MolFromSmiles(smile) | ||
N_atom = mol.GetNumAtoms() | ||
neighbors = [len(a.GetNeighbors()) for a in mol.GetAtoms()] | ||
# make sure that the ending members will be counted | ||
neighbors[0] += 1 | ||
neighbors[-1] += 1 | ||
patn_tor1 = Chem.MolFromSmarts(smarts_torsion1) | ||
torsion1 = cleaner(list(mol.GetSubstructMatches(patn_tor1)), neighbors) | ||
patn_tor2 = Chem.MolFromSmarts(smarts_torsion2) | ||
torsion2 = cleaner(list(mol.GetSubstructMatches(patn_tor2)), neighbors) | ||
tmp = cleaner(torsion1 + torsion2, neighbors) | ||
torsions = [] | ||
for t in tmp: | ||
(i, j, k, l) = t | ||
b = mol.GetBondBetweenAtoms(j, k) | ||
if not b.IsInRing(): | ||
torsions.append(t) | ||
# if len(torsions) > 6: torsions[1] = (4, 7, 10, 15) | ||
return torsions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was copy-pasted from the source at the top of the file, suggested by Chenghao - likely not to be used when we move to DGLConformer.
gflownet/utils/molecule/torsions.py
Outdated
import torch | ||
import networkx as nx | ||
import numpy as np | ||
|
||
from pytorch3d.transforms import axis_angle_to_matrix | ||
|
||
from gflownet.utils.molecule import constants | ||
|
||
|
||
def get_rotation_masks(dgl_graph): | ||
""" | ||
:param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...] | ||
""" | ||
nx_graph = nx.DiGraph(dgl_graph.to_networkx()) | ||
# bonds are indirected edges | ||
bonds = torch.stack(dgl_graph.edges()).numpy().T[::2] | ||
bonds_mask = np.zeros(bonds.shape[0], dtype=bool) | ||
nodes_mask = np.zeros((bonds.shape[0], dgl_graph.num_nodes()), dtype=bool) | ||
rotation_signs = np.zeros(bonds.shape[0], dtype=float) | ||
# fill in masks for bonds | ||
for bond_idx, bond in enumerate(bonds): | ||
modified_graph = nx_graph.to_undirected() | ||
modified_graph.remove_edge(*bond) | ||
if not nx.is_connected(modified_graph): | ||
smallest_component_nodes = sorted( | ||
nx.connected_components(modified_graph), key=len | ||
)[0] | ||
if len(smallest_component_nodes) > 1: | ||
bonds_mask[bond_idx] = True | ||
rotation_signs[bond_idx] = ( | ||
-1 if bond[0] in smallest_component_nodes else 1 | ||
) | ||
affected_nodes = np.array(list(smallest_component_nodes - set(bond))) | ||
nodes_mask[bond_idx, affected_nodes] = np.ones_like( | ||
affected_nodes, dtype=bool | ||
) | ||
|
||
# broadcast bond masks to edges masks | ||
edges_mask = torch.from_numpy(bonds_mask.repeat(2)) | ||
rotation_signs = torch.from_numpy(rotation_signs.repeat(2)) | ||
nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) | ||
return edges_mask, nodes_mask, rotation_signs | ||
|
||
|
||
def apply_rotations(graph, rotations): | ||
""" | ||
Apply rotations (torsion angles updates) | ||
:param dgl_graph: bidirectional dgl.Graph | ||
:param rotations: a sequence of torsion angle updates of length = number of bonds in the molecule. | ||
The order corresponds to the order of edges in the graph, such that action[i] is | ||
an update for the torsion angle corresponding to the edge[2i] | ||
""" | ||
pos = graph.ndata[constants.atom_position_name] | ||
edge_mask = graph.edata[constants.rotatable_edges_mask_name] | ||
node_mask = graph.edata[constants.rotation_affected_nodes_mask_name] | ||
rot_signs = graph.edata[constants.rotation_signs_name] | ||
edges = torch.stack(graph.edges()).T | ||
# TODO check how slow it is and whether it's possible to vectorise this loop | ||
for idx_update, update in enumerate(rotations): | ||
# import ipdb; ipdb.set_trace() | ||
idx_edge = idx_update * 2 | ||
if edge_mask[idx_edge]: | ||
begin_pos = pos[edges[idx_edge][0]] | ||
end_pos = pos[edges[idx_edge][1]] | ||
rot_vector = end_pos - begin_pos | ||
rot_vector = ( | ||
rot_vector | ||
/ torch.linalg.norm(rot_vector) | ||
* update | ||
* rot_signs[idx_edge] | ||
) | ||
rot_matrix = axis_angle_to_matrix(rot_vector) | ||
x = pos[node_mask[idx_edge]] | ||
pos[node_mask[idx_edge]] = ( | ||
torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos | ||
) | ||
graph.ndata[constants.atom_position_name] = pos | ||
return graph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is currently being significantly changed by @AlexandraVolokhova in #201, I'd suggest not reviewing for now.
…tion_fix_mk Formatting changes & typo fixes for TA fix
…xhernandezgarcia/gflownet into torsion_angles_detection_fix
…ezgarcia/gflownet into torsion_angles_detection_fix
…ezgarcia/gflownet into torsion_angles_detection_fix
…xhernandezgarcia/gflownet into torsion_angles_detection_fix
…xhernandezgarcia/gflownet into torsion_angles_detection_fix
…tion_fix Torsion angles detection fix
First round of changes for the sake of conformer experiments.
Works with the MLP policy (https://wandb.ai/michalkoziarski/GFlowNet/reports/Conformers-v5-beta-32---Vmlldzo1MTAwNTA1), contains some initial implementation of graph policy.
We tried to make it as small as possible!