Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Molecule graph environment #139

Open
wants to merge 127 commits into
base: main
Choose a base branch
from
Open

Molecule graph environment #139

wants to merge 127 commits into from

Conversation

michalkoziarski
Copy link
Collaborator

@michalkoziarski michalkoziarski commented Jun 16, 2023

First round of changes for the sake of conformer experiments.

Works with the MLP policy (https://wandb.ai/michalkoziarski/GFlowNet/reports/Conformers-v5-beta-32---Vmlldzo1MTAwNTA1), contains some initial implementation of graph policy.

We tried to make it as small as possible!

AlexandraVolokhova and others added 30 commits February 28, 2023 17:59
…ni_proxy

# Conflicts:
#	gflownet/envs/alaninedipeptide_mixture.py
#	gflownet/utils/molecule/datasets.py
#	gflownet/utils/molecule/dgl_conformer.py
#	gflownet/utils/molecule/old_conformer.py
#	gflownet/utils/molecule/rdkit_conformer.py
#	gflownet/utils/molecule/torsions.py
#	tests/gflownet/utils/molecule/test_torsions.py
Using TorchANI proxy for AlanineDipeptide
Comment on lines 1 to 65
# Taken from https://pyxtal.readthedocs.io/en/latest/_modules/pyxtal/molecule.html.

from operator import itemgetter


def find_rotor_from_smile(smile):
"""
Find the positions of rotatable bonds in the molecule.
"""

def cleaner(list_to_clean, neighbors):
"""
Remove duplicate torsion from a list of atom index tuples.
"""

for_remove = []
for x in reversed(range(len(list_to_clean))):
ix0 = itemgetter(0)(list_to_clean[x])
ix3 = itemgetter(3)(list_to_clean[x])
# for i-j-k-l, we don't want i, l are the ending members
# here C-C-S=O is not a good choice since O is only 1-coordinated
if neighbors[ix0] > 1 and neighbors[ix3] > 1:
for y in reversed(range(x)):
ix1 = itemgetter(1)(list_to_clean[x])
ix2 = itemgetter(2)(list_to_clean[x])
iy1 = itemgetter(1)(list_to_clean[y])
iy2 = itemgetter(2)(list_to_clean[y])
if [ix1, ix2] == [iy1, iy2] or [ix1, ix2] == [iy2, iy1]:
for_remove.append(y)
else:
for_remove.append(x)
clean_list = []
for i, v in enumerate(list_to_clean):
if i not in set(for_remove):
clean_list.append(v)
return clean_list

if smile in ["Cl-", "F-", "Br-", "I-", "Li+", "Na+"]:
return []
else:
from rdkit import Chem

smarts_torsion1 = "[*]~[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]~[*]"
smarts_torsion2 = "[*]~[^2]=[^2]~[*]" # C=C bonds
# smarts_torsion2="[*]~[^1]#[^1]~[*]" # C-C triples bonds, to be fixed

mol = Chem.MolFromSmiles(smile)
N_atom = mol.GetNumAtoms()
neighbors = [len(a.GetNeighbors()) for a in mol.GetAtoms()]
# make sure that the ending members will be counted
neighbors[0] += 1
neighbors[-1] += 1
patn_tor1 = Chem.MolFromSmarts(smarts_torsion1)
torsion1 = cleaner(list(mol.GetSubstructMatches(patn_tor1)), neighbors)
patn_tor2 = Chem.MolFromSmarts(smarts_torsion2)
torsion2 = cleaner(list(mol.GetSubstructMatches(patn_tor2)), neighbors)
tmp = cleaner(torsion1 + torsion2, neighbors)
torsions = []
for t in tmp:
(i, j, k, l) = t
b = mol.GetBondBetweenAtoms(j, k)
if not b.IsInRing():
torsions.append(t)
# if len(torsions) > 6: torsions[1] = (4, 7, 10, 15)
return torsions
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was copy-pasted from the source at the top of the file, suggested by Chenghao - likely not to be used when we move to DGLConformer.

Comment on lines 1 to 78
import torch
import networkx as nx
import numpy as np

from pytorch3d.transforms import axis_angle_to_matrix

from gflownet.utils.molecule import constants


def get_rotation_masks(dgl_graph):
"""
:param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...]
"""
nx_graph = nx.DiGraph(dgl_graph.to_networkx())
# bonds are indirected edges
bonds = torch.stack(dgl_graph.edges()).numpy().T[::2]
bonds_mask = np.zeros(bonds.shape[0], dtype=bool)
nodes_mask = np.zeros((bonds.shape[0], dgl_graph.num_nodes()), dtype=bool)
rotation_signs = np.zeros(bonds.shape[0], dtype=float)
# fill in masks for bonds
for bond_idx, bond in enumerate(bonds):
modified_graph = nx_graph.to_undirected()
modified_graph.remove_edge(*bond)
if not nx.is_connected(modified_graph):
smallest_component_nodes = sorted(
nx.connected_components(modified_graph), key=len
)[0]
if len(smallest_component_nodes) > 1:
bonds_mask[bond_idx] = True
rotation_signs[bond_idx] = (
-1 if bond[0] in smallest_component_nodes else 1
)
affected_nodes = np.array(list(smallest_component_nodes - set(bond)))
nodes_mask[bond_idx, affected_nodes] = np.ones_like(
affected_nodes, dtype=bool
)

# broadcast bond masks to edges masks
edges_mask = torch.from_numpy(bonds_mask.repeat(2))
rotation_signs = torch.from_numpy(rotation_signs.repeat(2))
nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0))
return edges_mask, nodes_mask, rotation_signs


def apply_rotations(graph, rotations):
"""
Apply rotations (torsion angles updates)
:param dgl_graph: bidirectional dgl.Graph
:param rotations: a sequence of torsion angle updates of length = number of bonds in the molecule.
The order corresponds to the order of edges in the graph, such that action[i] is
an update for the torsion angle corresponding to the edge[2i]
"""
pos = graph.ndata[constants.atom_position_name]
edge_mask = graph.edata[constants.rotatable_edges_mask_name]
node_mask = graph.edata[constants.rotation_affected_nodes_mask_name]
rot_signs = graph.edata[constants.rotation_signs_name]
edges = torch.stack(graph.edges()).T
# TODO check how slow it is and whether it's possible to vectorise this loop
for idx_update, update in enumerate(rotations):
# import ipdb; ipdb.set_trace()
idx_edge = idx_update * 2
if edge_mask[idx_edge]:
begin_pos = pos[edges[idx_edge][0]]
end_pos = pos[edges[idx_edge][1]]
rot_vector = end_pos - begin_pos
rot_vector = (
rot_vector
/ torch.linalg.norm(rot_vector)
* update
* rot_signs[idx_edge]
)
rot_matrix = axis_angle_to_matrix(rot_vector)
x = pos[node_mask[idx_edge]]
pos[node_mask[idx_edge]] = (
torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos
)
graph.ndata[constants.atom_position_name] = pos
return graph
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is currently being significantly changed by @AlexandraVolokhova in #201, I'd suggest not reviewing for now.

@michalkoziarski michalkoziarski changed the title [WIP] Molecule graph environment Molecule graph environment Sep 7, 2023
@michalkoziarski michalkoziarski marked this pull request as ready for review September 7, 2023 20:24
michalkoziarski and others added 23 commits September 7, 2023 18:35
…tion_fix_mk

Formatting changes & typo fixes for TA fix
…ezgarcia/gflownet into torsion_angles_detection_fix
…ezgarcia/gflownet into torsion_angles_detection_fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants