From 130338d904192f4545afa8f7243c86f25fda1206 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 28 Feb 2023 17:59:57 -0500 Subject: [PATCH 001/100] implemented rotation masks --- gflownet/envs/alaninedipeptide.py | 6 +-- gflownet/envs/alaninedipeptide_mixture.py | 6 +-- .../utils/molecule/atom_positions_dataset.py | 18 ------- gflownet/utils/molecule/datasets.py | 51 +++++++++++++++++++ gflownet/utils/molecule/dgl_conformer.py | 13 +++++ .../{conformer.py => old_conformer.py} | 4 +- .../{conformer_base.py => rdkit_conformer.py} | 8 +-- gflownet/utils/molecule/torsions.py | 27 ++++++++++ .../gflownet/utils/molecule/test_torsions.py | 33 ++++++++++++ 9 files changed, 136 insertions(+), 30 deletions(-) delete mode 100644 gflownet/utils/molecule/atom_positions_dataset.py create mode 100644 gflownet/utils/molecule/datasets.py create mode 100644 gflownet/utils/molecule/dgl_conformer.py rename gflownet/utils/molecule/{conformer.py => old_conformer.py} (98%) rename gflownet/utils/molecule/{conformer_base.py => rdkit_conformer.py} (98%) create mode 100644 gflownet/utils/molecule/torsions.py create mode 100644 tests/gflownet/utils/molecule/test_torsions.py diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index 023e2fd6a..7dcf77af0 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -7,8 +7,8 @@ from gflownet.envs.ctorus import ContinuousTorus from gflownet.utils.molecule import constants -from gflownet.utils.molecule.atom_positions_dataset import AtomPositionsDataset -from gflownet.utils.molecule.conformer_base import ConformerBase +from gflownet.utils.molecule.datasets import AtomPositionsDataset +from gflownet.utils.molecule.rdkit_conformer import RDKitConformer class AlanineDipeptide(ContinuousTorus): @@ -37,7 +37,7 @@ def __init__( ): self.atom_positions_dataset = AtomPositionsDataset(path_to_dataset, url_to_dataset) atom_positions = self.atom_positions_dataset.sample() - self.conformer = ConformerBase( + self.conformer = RDKitConformer( atom_positions, constants.ad_smiles, constants.ad_free_tas ) n_dim = len(self.conformer.freely_rotatable_tas) diff --git a/gflownet/envs/alaninedipeptide_mixture.py b/gflownet/envs/alaninedipeptide_mixture.py index 5139725ab..c95458092 100644 --- a/gflownet/envs/alaninedipeptide_mixture.py +++ b/gflownet/envs/alaninedipeptide_mixture.py @@ -8,8 +8,8 @@ from gflownet.envs.ctorusmixture import ContinuousTorusMixture from gflownet.utils.molecule import constants -from gflownet.utils.molecule.atom_positions_dataset import AtomPositionsDataset -from gflownet.utils.molecule.conformer_base import ConformerBase +from gflownet.utils.molecule.datasets import AtomPositionsDataset +from gflownet.utils.molecule.rdkit_conformer import RDKitConformer class AlanineDipeptideMixture(ContinuousTorusMixture): @@ -40,7 +40,7 @@ def __init__( self.atom_positions_dataset = AtomPositionsDataset(path_to_dataset, url_to_dataset) atom_positions = self.atom_positions_dataset.sample() - self.conformer = ConformerBase( + self.conformer = RDKitConformer( atom_positions, constants.ad_smiles, constants.ad_free_tas ) n_dim = len(self.conformer.freely_rotatable_tas) diff --git a/gflownet/utils/molecule/atom_positions_dataset.py b/gflownet/utils/molecule/atom_positions_dataset.py deleted file mode 100644 index 0ed1eb171..000000000 --- a/gflownet/utils/molecule/atom_positions_dataset.py +++ /dev/null @@ -1,18 +0,0 @@ -import numpy as np - -from gflownet.utils.common import download_file_if_not_exists - -class AtomPositionsDataset: - def __init__(self, path_to_data, url_to_data): - path_to_data = download_file_if_not_exists(path_to_data, url_to_data) - self.positions = np.load(path_to_data) - - def __getitem__(self, i): - return self.positions[i] - - def __len__(self): - return self.positions.shape[0] - - def sample(self, size=None): - idx = np.random.randint(0, len(self), size=size) - return self.positions[idx] diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py new file mode 100644 index 000000000..eedfc807b --- /dev/null +++ b/gflownet/utils/molecule/datasets.py @@ -0,0 +1,51 @@ +import dgl +import numpy as np + +from gflownet.utils.common import download_file_if_not_exists +from gflownet.utils.molecule import constants +from gflownet.utils.molecule.dgl_conformer import DGLConformer + +class AtomPositionsDataset: + def __init__(self, path_to_data, url_to_data): + path_to_data = download_file_if_not_exists(path_to_data, url_to_data) + self.positions = np.load(path_to_data) + + def __getitem__(self, i): + return self.positions[i] + + def __len__(self): + return self.positions.shape[0] + + def sample(self, size=None): + idx = np.random.randint(0, len(self), size=size) + return self.positions[idx] + +class ConformersDataset: + def __init__(self, path_to_data, url_to_data): + # TODO create a new dataset if path_to_data or url_to_data doesn't exist + path_to_data = download_file_if_not_exists(path_to_data, url_to_data) + with open(path_to_data, 'rb') as inp: + self.conformers = pickle.load(inp) + + def get_conformer(self): + """ + Returns dgl graph with features stored in the dataset: + - ndata: + - atom features + - atomic numbers + - atom position + - edata: + - edge features + - rotatable bonds mask + """ + smiles = np.random.choice(self.conformers.keys()) + edges = self.conformers[smiles]['edges'] + graph = dgl.graph(edges) + graph.ndata[constants.atom_feature_name] = self.conformers[smiles][constants.atom_feature_name] + graph.ndata[constants.atomic_numbers_name] = self.conformers[smiles][constants.atomic_numbers_name] + graph.edata[constants.edge_feature_name] = self.conformers[smiles][constants.edge_feature_name] + graph.edata[constants.rotatable_bonds_mask] = self.conformers[smiles][constants.rotatable_bonds_mask] + conf_idx = np.random.randint(0, self.conformers[smiles][constants.atom_position_name].shape[0]) + graph.ndata[constants.atom_position_name] = self.conformers[smiles][constants.atom_position_name][conf_idx] + conformer = DGLConformer(graph) + return smiles, conformer \ No newline at end of file diff --git a/gflownet/utils/molecule/dgl_conformer.py b/gflownet/utils/molecule/dgl_conformer.py new file mode 100644 index 000000000..acfab419b --- /dev/null +++ b/gflownet/utils/molecule/dgl_conformer.py @@ -0,0 +1,13 @@ +import torch + +class DGLConformer: + def __init__(self, dgl_graph): + self.graph = dgl_graph + + def increment_torsion_angles(self, increments): + raise NotImplementedError + + def randomise_torsion_angles(self): + raise NotImplementedError + + \ No newline at end of file diff --git a/gflownet/utils/molecule/conformer.py b/gflownet/utils/molecule/old_conformer.py similarity index 98% rename from gflownet/utils/molecule/conformer.py rename to gflownet/utils/molecule/old_conformer.py index 1fffc9fe8..de3bc04a2 100644 --- a/gflownet/utils/molecule/conformer.py +++ b/gflownet/utils/molecule/old_conformer.py @@ -11,10 +11,10 @@ from gflownet.utils.molecule import constants from gflownet.utils.molecule.featurizer import MolDGLFeaturizer -from gflownet.utils.molecule.conformer_base import ConformerBase +from gflownet.utils.molecule.rdkit_conformer import RDKitConformer -class Conformer(ConformerBase): +class Conformer(RDKitConformer): def __init__(self, atom_positions, smiles, atom_types, freely_rotatable_tas=None): """ :param atom_positions: numpy.ndarray of shape [num_atoms, 3] of dtype float64 diff --git a/gflownet/utils/molecule/conformer_base.py b/gflownet/utils/molecule/rdkit_conformer.py similarity index 98% rename from gflownet/utils/molecule/conformer_base.py rename to gflownet/utils/molecule/rdkit_conformer.py index 686fe119f..adeffb16d 100644 --- a/gflownet/utils/molecule/conformer_base.py +++ b/gflownet/utils/molecule/rdkit_conformer.py @@ -37,15 +37,15 @@ def get_dummy_ad_atom_positions(): return rconf.GetPositions() -def get_dummy_ad_conf_base(): +def get_dummy_ad_rdkconf(): pos = get_dummy_ad_atom_positions() - conf = ConformerBase( + conf = RDKitConformer( pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas ) return conf -class ConformerBase: +class RDKitConformer: def __init__(self, atom_positions, smiles, freely_rotatable_tas=None): """ :param atom_positions: numpy.ndarray of shape [num_atoms, 3] of dtype float64 @@ -147,7 +147,7 @@ def increment_torsion_angle(self, torsion_angle, increment): test_pos = rconf.GetPositions() initial_tas = get_all_torsion_angles(rmol, rconf) - conf = ConformerBase( + conf = RDKitConformer( test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas ) # check torsion angles randomisation diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py new file mode 100644 index 000000000..e8d52e682 --- /dev/null +++ b/gflownet/utils/molecule/torsions.py @@ -0,0 +1,27 @@ +import torch +import networkx as nx +import numpy as np + +def get_rotation_masks(dgl_graph): + """ + :param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...] + """ + nx_graph = nx.DiGraph(dgl_graph.to_networkx()) + # bonds are indirected edges + bonds = torch.stack(dgl_graph.edges()).numpy().T[::2] + bonds_mask = np.zeros(bonds.shape[0], dtype=bool) + nodes_mask = np.zeros((bonds.shape[0], dgl_graph.num_nodes()), dtype=bool) + # fill in masks for bonds + for bond_idx, bond in enumerate(bonds): + modified_graph = nx_graph.to_undirected() + modified_graph.remove_edge(*bond) + if not nx.is_connected(modified_graph): + smallest_component_nodes = sorted(nx.connected_components(modified_graph), key=len)[0] + if len(smallest_component_nodes) > 1: + bonds_mask[bond_idx] = True + affected_nodes = np.array(list(smallest_component_nodes - set(bond))) + nodes_mask[bond_idx, affected_nodes] = np.ones_like(affected_nodes, dtype=bool) + # broadcast bond masks to edges masks + edges_mask = torch.from_numpy(bonds_mask.repeat(2)) + nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) + return edges_mask, nodes_mask diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py new file mode 100644 index 000000000..f1e3e65bf --- /dev/null +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -0,0 +1,33 @@ +import pytest +import torch +import dgl + +from gflownet.utils.molecule.torsions import get_rotation_masks + +def test_four_nodes_chain(): + graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) + edges_mask, nodes_mask = get_rotation_masks(graph) + correct_edges_mask = torch.tensor([False, False, True, True, False, False]) + correct_nodes_mask = torch.tensor([[False, False, False, False], + [False, False, False, False], + [ True, False, False, False], + [ True, False, False, False], + [False, False, False, False], + [False, False, False, False]]) + assert torch.all(edges_mask == correct_edges_mask) + assert torch.all(nodes_mask == correct_nodes_mask) + +def test_choose_smallest_component(): + graph = dgl.graph(([0, 2, 1, 2, 2, 3, 3, 4], [2, 0, 2, 1, 3, 2, 4, 3])) + edges_mask, nodes_mask = get_rotation_masks(graph) + correct_edges_mask = torch.tensor([False, False, False, False, True, True, False, False]) + correct_nodes_mask = torch.tensor([[False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, True], + [False, False, False, False, True], + [False, False, False, False, False], + [False, False, False, False, False]]) + assert torch.all(edges_mask == correct_edges_mask) + assert torch.all(nodes_mask == correct_nodes_mask) \ No newline at end of file From fe9d5acf7a554a41f8592848d9548b7dc0120e03 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 28 Feb 2023 18:11:57 -0500 Subject: [PATCH 002/100] add masks to featuraser --- gflownet/utils/molecule/constants.py | 2 ++ gflownet/utils/molecule/featurizer.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/gflownet/utils/molecule/constants.py b/gflownet/utils/molecule/constants.py index d2786b6cb..e74c3616c 100644 --- a/gflownet/utils/molecule/constants.py +++ b/gflownet/utils/molecule/constants.py @@ -6,6 +6,8 @@ edge_feature_name = "edge_features" step_feature_name = "step" atomic_numbers_name = "atomic_numbers" +rotatable_edges_mask_name = "rotatable_edges" +rotation_affected_nodes_mask_name = "rotation_affected_nodes" # Options for atoms featurization ad_atom_types = ("H", "C", "N", "O") diff --git a/gflownet/utils/molecule/featurizer.py b/gflownet/utils/molecule/featurizer.py index f96a82fdd..caa480a8d 100644 --- a/gflownet/utils/molecule/featurizer.py +++ b/gflownet/utils/molecule/featurizer.py @@ -2,6 +2,7 @@ import torch from gflownet.utils.molecule import constants +from gflownet.utils.molecule.torsions import get_rotation_masks class MolDGLFeaturizer: @@ -107,6 +108,9 @@ def mol2dgl(self, mol): graph.ndata[constants.atom_feature_name] = node_features graph.ndata[constants.atomic_numbers_name] = self.get_atomic_numbers(mol) graph.edata[constants.edge_feature_name] = edge_features + edges_mask, nodes_mask = get_rotation_masks(graph) + graph.edata[constants.rotatable_edges_mask_name] = edges_mask + graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask return graph @@ -123,5 +127,6 @@ def mol2dgl(self, mol): print("node features shape:", graph.ndata[constants.atom_feature_name].shape) print("edge features shape:", graph.edata[constants.edge_feature_name].shape) print("edges:", *graph.edges(), sep="\n") + print(graph.edata[constants.rotatable_edges_mask_name]) assert graph.ndata[constants.atom_feature_name].shape[0] == mol.GetNumAtoms() assert graph.edata[constants.edge_feature_name].shape[0] == 2 * mol.GetNumBonds() From 6c9de563050d0a19fa7d1393cb4d3754b339ca1c Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 28 Feb 2023 18:55:38 -0500 Subject: [PATCH 003/100] implemented apply rotations --- gflownet/utils/molecule/dgl_conformer.py | 8 +++++- gflownet/utils/molecule/torsions.py | 33 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/molecule/dgl_conformer.py b/gflownet/utils/molecule/dgl_conformer.py index acfab419b..ce08c6511 100644 --- a/gflownet/utils/molecule/dgl_conformer.py +++ b/gflownet/utils/molecule/dgl_conformer.py @@ -4,7 +4,13 @@ class DGLConformer: def __init__(self, dgl_graph): self.graph = dgl_graph - def increment_torsion_angles(self, increments): + def apply_rotations(self, rotations): + """ + Apply rotations (torsion angles updates) + :param rotations: a sequence of torsion angle updates of length = number of bonds in the molecule. + The order corresponds to the order of edges in self.graph, such that action[i] is + an update for the torsion angle corresponding to the edge[2i] + """ raise NotImplementedError def randomise_torsion_angles(self): diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index e8d52e682..2f1927405 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -2,6 +2,10 @@ import networkx as nx import numpy as np +from pytorch3d.transforms import axis_angle_to_matrix + +from gflownet.utils.molecule import constants + def get_rotation_masks(dgl_graph): """ :param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...] @@ -25,3 +29,32 @@ def get_rotation_masks(dgl_graph): edges_mask = torch.from_numpy(bonds_mask.repeat(2)) nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) return edges_mask, nodes_mask + +def apply_rotations(dgl_graph, rotations): + """ + Apply rotations (torsion angles updates) + :param dgl_graph: bidirectional dgl.Graph + :param rotations: a sequence of torsion angle updates of length = number of bonds in the molecule. + The order corresponds to the order of edges in the graph, such that action[i] is + an update for the torsion angle corresponding to the edge[2i] + """ + pos = graph.ndata[constants.atom_position_name] + edge_mask = graph.edata[constants.rotatable_edges_mask_name] + node_mask = graph.edata[constants.rotation_affected_nodes_mask_name] + edges = torch.stack(graph.edges()).T + # TODO check how slow it is and whether it's possible to vectorise this loop + for idx_update, update in enumerate(rotations): + idx_edge = idx_update * 2 + if edge_mask[idx_edge]: + begin_pos = pos[edges[idx_edge][0]] + end_pos = pos[edges[idx_edge][1]] + rot_vector = end_pos - begin_pos + rot_vector = rot_vector / torch.linalg.norm(rot_vector) * update + rot_matrix = axis_angle_to_matrix(rot_vector) + x = pos[node_mask[idx_edge]] + pos[node_mask[idx_edge]] = torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos + dgl_graph.ndata[constants.atom_position_name] = pos + return dgl_graph + + + From 99fb7c78adf214e254cd136717c23394387ac475 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Wed, 1 Mar 2023 19:20:56 -0500 Subject: [PATCH 004/100] add simple tests for apply rotations --- gflownet/utils/molecule/torsions.py | 36 ++++- .../gflownet/utils/molecule/test_torsions.py | 140 +++++++++++++++++- 2 files changed, 171 insertions(+), 5 deletions(-) diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index 2f1927405..06551cd91 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -30,7 +30,7 @@ def get_rotation_masks(dgl_graph): nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) return edges_mask, nodes_mask -def apply_rotations(dgl_graph, rotations): +def apply_rotations(graph, rotations): """ Apply rotations (torsion angles updates) :param dgl_graph: bidirectional dgl.Graph @@ -53,8 +53,38 @@ def apply_rotations(dgl_graph, rotations): rot_matrix = axis_angle_to_matrix(rot_vector) x = pos[node_mask[idx_edge]] pos[node_mask[idx_edge]] = torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos - dgl_graph.ndata[constants.atom_position_name] = pos - return dgl_graph + graph.ndata[constants.atom_position_name] = pos + return graph +if __name__ == '__main__': + from rdkit import Chem + from rdkit.Chem import AllChem + from rdkit.Chem import rdMolTransforms + from rdkit.Chem import TorsionFingerprints + from rdkit.Geometry.rdGeometry import Point3D + from gflownet.utils.molecule.featurizer import MolDGLFeaturizer + from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values + + mol = Chem.MolFromSmiles(constants.ad_smiles) + mol = Chem.AddHs(mol) + AllChem.EmbedMolecule(mol) + rconf = mol.GetConformer() + start_pos = rconf.GetPositions() + + featurizer = MolDGLFeaturizer(constants.ad_atom_types) + graph = featurizer.mol2dgl(mol) + graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) + bonds = torch.stack(graph.edges())[:,::2] + print(bonds) + print(graph.edata[constants.rotatable_edges_mask_name][::2]) + print(bonds[:, graph.edata[constants.rotatable_edges_mask_name][::2]]) + torsion_angles = [(10, 0, 1, 6)] + print(get_torsion_angles_values(rconf, torsion_angles)) + torsion_angles = [(11, 0, 1, 6)] + print(get_torsion_angles_values(rconf, torsion_angles)) + torsion_angles = [(6, 1, 0, 10)] + print(get_torsion_angles_values(rconf, torsion_angles)) + torsion_angles = [(6, 0, 1, 10)] + print(get_torsion_angles_values(rconf, torsion_angles)) diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index f1e3e65bf..e8a1795ce 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -2,7 +2,8 @@ import torch import dgl -from gflownet.utils.molecule.torsions import get_rotation_masks +from gflownet.utils.molecule.torsions import get_rotation_masks, apply_rotations +from gflownet.utils.molecule import constants def test_four_nodes_chain(): graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) @@ -30,4 +31,139 @@ def test_choose_smallest_component(): [False, False, False, False, False], [False, False, False, False, False]]) assert torch.all(edges_mask == correct_edges_mask) - assert torch.all(nodes_mask == correct_nodes_mask) \ No newline at end of file + assert torch.all(nodes_mask == correct_nodes_mask) + +@pytest.mark.parametrize( + "angle, exp_result", + [ + ( + torch.pi / 2, + torch.tensor( + [[1., 0., 1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi, + torch.tensor( + [[2., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 3 / 2, + torch.tensor( + [[1., 0., -1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 2, + torch.tensor( + [[0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + + ] + +) +def test_apply_rotations_simple(angle, exp_result): + graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) + graph.ndata[constants.atom_position_name] = torch.tensor([ + [0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.] + ]) + edges_mask, nodes_mask = get_rotation_masks(graph) + graph.edata[constants.rotatable_edges_mask_name] = edges_mask + graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + rotations = torch.tensor([0., angle, 0.]) + result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] + assert torch.allclose(result, exp_result, atol=1e-6) + + +@pytest.mark.parametrize( + "angle, exp_result", + [ + ( + torch.pi / 2, + torch.tensor( + [[1., 0., 1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi, + torch.tensor( + [[2., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 3 / 2, + torch.tensor( + [[1., 0., -1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 2, + torch.tensor( + [[0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + + ] + +) +def test_apply_rotations_ignore_nonrotatable(angle, exp_result): + graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) + graph.ndata[constants.atom_position_name] = torch.tensor([ + [0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.] + ]) + edges_mask, nodes_mask = get_rotation_masks(graph) + graph.edata[constants.rotatable_edges_mask_name] = edges_mask + graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + rotations = torch.tensor([2., angle, -1.]) + result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] + assert torch.allclose(result, exp_result, atol=1e-6) + +# def test_apply_rotation_alanine_dipeptide(): +# from rdkit import Chem + +# mol = Chem.MolFromSmiles(constants.ad_smiles) +# mol = Chem.AddHs(mol) +# AllChem.EmbedMolecule(rmol) +# rconf = rmol.GetConformer() +# start_pos = rconf.GetPositions() + +# featurizer = MolDGLFeaturizer(constants.ad_atom_types) + +# graph = featurizer.mol2dgl(mol) +# graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) + + +# rmol = Chem.MolFromSmiles(constants.ad_smiles) +# rmol = Chem.AddHs(rmol) +# AllChem.EmbedMolecule(rmol) +# rconf = rmol.GetConformer() +# test_pos = rconf.GetPositions() +# initial_tas = get_all_torsion_angles(rmol, rconf) + +# conf = RDKitConformer( +# test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas +# ) \ No newline at end of file From d1713cb018ff7425ae5fb5a6527a64fb81c8a032 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 2 Mar 2023 17:24:59 -0500 Subject: [PATCH 005/100] fix bug in torsios, add test with AD --- gflownet/utils/molecule/constants.py | 1 + gflownet/utils/molecule/featurizer.py | 3 +- gflownet/utils/molecule/torsions.py | 45 ++------- .../gflownet/utils/molecule/test_torsions.py | 92 ++++++++++++------- 4 files changed, 71 insertions(+), 70 deletions(-) diff --git a/gflownet/utils/molecule/constants.py b/gflownet/utils/molecule/constants.py index e74c3616c..ce963c3a1 100644 --- a/gflownet/utils/molecule/constants.py +++ b/gflownet/utils/molecule/constants.py @@ -8,6 +8,7 @@ atomic_numbers_name = "atomic_numbers" rotatable_edges_mask_name = "rotatable_edges" rotation_affected_nodes_mask_name = "rotation_affected_nodes" +rotation_signs_name = "rotation_signs" # Options for atoms featurization ad_atom_types = ("H", "C", "N", "O") diff --git a/gflownet/utils/molecule/featurizer.py b/gflownet/utils/molecule/featurizer.py index caa480a8d..51be811db 100644 --- a/gflownet/utils/molecule/featurizer.py +++ b/gflownet/utils/molecule/featurizer.py @@ -108,9 +108,10 @@ def mol2dgl(self, mol): graph.ndata[constants.atom_feature_name] = node_features graph.ndata[constants.atomic_numbers_name] = self.get_atomic_numbers(mol) graph.edata[constants.edge_feature_name] = edge_features - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + graph.edata[constants.rotation_signs_name] = rotation_signs return graph diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index 06551cd91..b210c0e2e 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -15,6 +15,7 @@ def get_rotation_masks(dgl_graph): bonds = torch.stack(dgl_graph.edges()).numpy().T[::2] bonds_mask = np.zeros(bonds.shape[0], dtype=bool) nodes_mask = np.zeros((bonds.shape[0], dgl_graph.num_nodes()), dtype=bool) + rotation_signs = np.zeros(bonds.shape[0], dtype=float) # fill in masks for bonds for bond_idx, bond in enumerate(bonds): modified_graph = nx_graph.to_undirected() @@ -23,12 +24,15 @@ def get_rotation_masks(dgl_graph): smallest_component_nodes = sorted(nx.connected_components(modified_graph), key=len)[0] if len(smallest_component_nodes) > 1: bonds_mask[bond_idx] = True + rotation_signs[bond_idx] = -1 if bond[0] in smallest_component_nodes else 1 affected_nodes = np.array(list(smallest_component_nodes - set(bond))) nodes_mask[bond_idx, affected_nodes] = np.ones_like(affected_nodes, dtype=bool) + # broadcast bond masks to edges masks edges_mask = torch.from_numpy(bonds_mask.repeat(2)) + rotation_signs = torch.from_numpy(rotation_signs.repeat(2)) nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) - return edges_mask, nodes_mask + return edges_mask, nodes_mask, rotation_signs def apply_rotations(graph, rotations): """ @@ -41,50 +45,19 @@ def apply_rotations(graph, rotations): pos = graph.ndata[constants.atom_position_name] edge_mask = graph.edata[constants.rotatable_edges_mask_name] node_mask = graph.edata[constants.rotation_affected_nodes_mask_name] + rot_signs = graph.edata[constants.rotation_signs_name] edges = torch.stack(graph.edges()).T # TODO check how slow it is and whether it's possible to vectorise this loop for idx_update, update in enumerate(rotations): + # import ipdb; ipdb.set_trace() idx_edge = idx_update * 2 if edge_mask[idx_edge]: begin_pos = pos[edges[idx_edge][0]] end_pos = pos[edges[idx_edge][1]] rot_vector = end_pos - begin_pos - rot_vector = rot_vector / torch.linalg.norm(rot_vector) * update + rot_vector = rot_vector / torch.linalg.norm(rot_vector) * update * rot_signs[idx_edge] rot_matrix = axis_angle_to_matrix(rot_vector) x = pos[node_mask[idx_edge]] pos[node_mask[idx_edge]] = torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos graph.ndata[constants.atom_position_name] = pos - return graph - - - -if __name__ == '__main__': - from rdkit import Chem - from rdkit.Chem import AllChem - from rdkit.Chem import rdMolTransforms - from rdkit.Chem import TorsionFingerprints - from rdkit.Geometry.rdGeometry import Point3D - from gflownet.utils.molecule.featurizer import MolDGLFeaturizer - from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values - - mol = Chem.MolFromSmiles(constants.ad_smiles) - mol = Chem.AddHs(mol) - AllChem.EmbedMolecule(mol) - rconf = mol.GetConformer() - start_pos = rconf.GetPositions() - - featurizer = MolDGLFeaturizer(constants.ad_atom_types) - graph = featurizer.mol2dgl(mol) - graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) - bonds = torch.stack(graph.edges())[:,::2] - print(bonds) - print(graph.edata[constants.rotatable_edges_mask_name][::2]) - print(bonds[:, graph.edata[constants.rotatable_edges_mask_name][::2]]) - torsion_angles = [(10, 0, 1, 6)] - print(get_torsion_angles_values(rconf, torsion_angles)) - torsion_angles = [(11, 0, 1, 6)] - print(get_torsion_angles_values(rconf, torsion_angles)) - torsion_angles = [(6, 1, 0, 10)] - print(get_torsion_angles_values(rconf, torsion_angles)) - torsion_angles = [(6, 0, 1, 10)] - print(get_torsion_angles_values(rconf, torsion_angles)) + return graph \ No newline at end of file diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index e8a1795ce..b5ba74997 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -2,12 +2,18 @@ import torch import dgl +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Geometry.rdGeometry import Point3D + from gflownet.utils.molecule.torsions import get_rotation_masks, apply_rotations from gflownet.utils.molecule import constants +from gflownet.utils.molecule.featurizer import MolDGLFeaturizer +from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values def test_four_nodes_chain(): graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) correct_edges_mask = torch.tensor([False, False, True, True, False, False]) correct_nodes_mask = torch.tensor([[False, False, False, False], [False, False, False, False], @@ -15,12 +21,14 @@ def test_four_nodes_chain(): [ True, False, False, False], [False, False, False, False], [False, False, False, False]]) + correct_rotation_signs = torch.tensor([ 0., 0., -1., -1., 0., 0.]) assert torch.all(edges_mask == correct_edges_mask) assert torch.all(nodes_mask == correct_nodes_mask) + assert torch.all(rotation_signs == correct_rotation_signs) def test_choose_smallest_component(): graph = dgl.graph(([0, 2, 1, 2, 2, 3, 3, 4], [2, 0, 2, 1, 3, 2, 4, 3])) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) correct_edges_mask = torch.tensor([False, False, False, False, True, True, False, False]) correct_nodes_mask = torch.tensor([[False, False, False, False, False], [False, False, False, False, False], @@ -30,8 +38,10 @@ def test_choose_smallest_component(): [False, False, False, False, True], [False, False, False, False, False], [False, False, False, False, False]]) + correct_rotation_signs = torch.tensor([0., 0., 0., 0., 1., 1., 0., 0.]) assert torch.all(edges_mask == correct_edges_mask) assert torch.all(nodes_mask == correct_nodes_mask) + assert torch.all(rotation_signs == correct_rotation_signs) @pytest.mark.parametrize( "angle, exp_result", @@ -39,7 +49,7 @@ def test_choose_smallest_component(): ( torch.pi / 2, torch.tensor( - [[1., 0., 1.], + [[1., 0., -1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -55,7 +65,7 @@ def test_choose_smallest_component(): ( torch.pi * 3 / 2, torch.tensor( - [[1., 0., -1.], + [[1., 0., 1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -80,9 +90,10 @@ def test_apply_rotations_simple(angle, exp_result): [1., 1., 0.], [2., 1., 0.] ]) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + graph.edata[constants.rotation_signs_name] = rotation_signs rotations = torch.tensor([0., angle, 0.]) result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] assert torch.allclose(result, exp_result, atol=1e-6) @@ -94,7 +105,7 @@ def test_apply_rotations_simple(angle, exp_result): ( torch.pi / 2, torch.tensor( - [[1., 0., 1.], + [[1., 0., -1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -110,7 +121,7 @@ def test_apply_rotations_simple(angle, exp_result): ( torch.pi * 3 / 2, torch.tensor( - [[1., 0., -1.], + [[1., 0., 1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -135,35 +146,50 @@ def test_apply_rotations_ignore_nonrotatable(angle, exp_result): [1., 1., 0.], [2., 1., 0.] ]) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + graph.edata[constants.rotation_signs_name] = rotation_signs rotations = torch.tensor([2., angle, -1.]) result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] assert torch.allclose(result, exp_result, atol=1e-6) -# def test_apply_rotation_alanine_dipeptide(): -# from rdkit import Chem - -# mol = Chem.MolFromSmiles(constants.ad_smiles) -# mol = Chem.AddHs(mol) -# AllChem.EmbedMolecule(rmol) -# rconf = rmol.GetConformer() -# start_pos = rconf.GetPositions() - -# featurizer = MolDGLFeaturizer(constants.ad_atom_types) - -# graph = featurizer.mol2dgl(mol) -# graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) - - -# rmol = Chem.MolFromSmiles(constants.ad_smiles) -# rmol = Chem.AddHs(rmol) -# AllChem.EmbedMolecule(rmol) -# rconf = rmol.GetConformer() -# test_pos = rconf.GetPositions() -# initial_tas = get_all_torsion_angles(rmol, rconf) - -# conf = RDKitConformer( -# test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas -# ) \ No newline at end of file +def stress_test_apply_rotation_alanine_dipeptide(): + from rdkit import Chem + from rdkit.Chem import AllChem + from rdkit.Geometry.rdGeometry import Point3D + from gflownet.utils.molecule.featurizer import MolDGLFeaturizer + from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values + + mol = Chem.MolFromSmiles(constants.ad_smiles) + mol = Chem.AddHs(mol) + AllChem.EmbedMolecule(mol) + rconf = mol.GetConformer() + start_pos = rconf.GetPositions() + + featurizer = MolDGLFeaturizer(constants.ad_atom_types) + graph = featurizer.mol2dgl(mol) + graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) + + torsion_angles = [ + (10, 0, 1, 6), + (0, 1, 2, 3), + (1, 2, 4, 14), + (2, 4, 5, 15), + (0, 1, 6, 7), + (18, 6, 7, 8), + (8, 7, 9, 19) + ] + n_edges = graph.edges()[0].shape[-1] + for _ in range (100): + ta_initial_values = torch.tensor(get_torsion_angles_values(rconf, torsion_angles)) + + rotations = torch.rand(n_edges // 2) * torch.pi * 2 + graph = apply_rotations(graph, rotations) + new_pos = graph.ndata[constants.atom_position_name].numpy() + for idx, pos in enumerate(new_pos): + rconf.SetAtomPosition(idx, Point3D(*pos)) + ta_updated_values = torch.tensor(get_torsion_angles_values(rconf, torsion_angles)) + valid_rotations = rotations[graph.edata[constants.rotatable_edges_mask_name][::2]] + diff = (ta_updated_values - ta_initial_values - valid_rotations) % (2*torch.pi) + assert torch.logical_or(torch.isclose(diff, torch.zeros_like(diff), atol=1e-6), torch.isclose(diff, torch.ones_like(diff)*2*torch.pi, atol=1e-5)).all() \ No newline at end of file From 217099cc93a626bf0d3d6de0b077dfdb555a5f75 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 11 Apr 2023 18:07:43 -0400 Subject: [PATCH 006/100] add a comment to ConformerDataset --- gflownet/utils/molecule/datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py index eedfc807b..cfb43dba4 100644 --- a/gflownet/utils/molecule/datasets.py +++ b/gflownet/utils/molecule/datasets.py @@ -38,6 +38,7 @@ def get_conformer(self): - edge features - rotatable bonds mask """ + # TODO make it work if there're several conformers for a single molecule smiles = np.random.choice(self.conformers.keys()) edges = self.conformers[smiles]['edges'] graph = dgl.graph(edges) @@ -48,4 +49,4 @@ def get_conformer(self): conf_idx = np.random.randint(0, self.conformers[smiles][constants.atom_position_name].shape[0]) graph.ndata[constants.atom_position_name] = self.conformers[smiles][constants.atom_position_name][conf_idx] conformer = DGLConformer(graph) - return smiles, conformer \ No newline at end of file + return smiles, conformer From b000332d42f2d1c05c65e411a2ccbd2b27d773fc Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 28 Feb 2023 17:59:57 -0500 Subject: [PATCH 007/100] implemented rotation masks --- gflownet/envs/alaninedipeptide.py | 6 +-- .../utils/molecule/atom_positions_dataset.py | 19 ------- gflownet/utils/molecule/datasets.py | 51 +++++++++++++++++++ gflownet/utils/molecule/dgl_conformer.py | 13 +++++ .../{conformer.py => old_conformer.py} | 3 +- .../{conformer_base.py => rdkit_conformer.py} | 10 ++-- gflownet/utils/molecule/torsions.py | 27 ++++++++++ .../gflownet/utils/molecule/test_torsions.py | 33 ++++++++++++ 8 files changed, 135 insertions(+), 27 deletions(-) delete mode 100644 gflownet/utils/molecule/atom_positions_dataset.py create mode 100644 gflownet/utils/molecule/datasets.py create mode 100644 gflownet/utils/molecule/dgl_conformer.py rename gflownet/utils/molecule/{conformer.py => old_conformer.py} (98%) rename gflownet/utils/molecule/{conformer_base.py => rdkit_conformer.py} (96%) create mode 100644 gflownet/utils/molecule/torsions.py create mode 100644 tests/gflownet/utils/molecule/test_torsions.py diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index 76b725e3b..030c4a32e 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -8,8 +8,8 @@ from gflownet.envs.ctorus import ContinuousTorus from gflownet.utils.molecule import constants -from gflownet.utils.molecule.atom_positions_dataset import AtomPositionsDataset -from gflownet.utils.molecule.conformer_base import ConformerBase +from gflownet.utils.molecule.datasets import AtomPositionsDataset +from gflownet.utils.molecule.rdkit_conformer import RDKitConformer class AlanineDipeptide(ContinuousTorus): @@ -26,7 +26,7 @@ def __init__( path_to_dataset, url_to_dataset ) atom_positions = self.atom_positions_dataset.sample() - self.conformer = ConformerBase( + self.conformer = RDKitConformer( atom_positions, constants.ad_smiles, constants.ad_free_tas ) n_dim = len(self.conformer.freely_rotatable_tas) diff --git a/gflownet/utils/molecule/atom_positions_dataset.py b/gflownet/utils/molecule/atom_positions_dataset.py deleted file mode 100644 index 0b66f4363..000000000 --- a/gflownet/utils/molecule/atom_positions_dataset.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np - -from gflownet.utils.common import download_file_if_not_exists - - -class AtomPositionsDataset: - def __init__(self, path_to_data, url_to_data): - path_to_data = download_file_if_not_exists(path_to_data, url_to_data) - self.positions = np.load(path_to_data) - - def __getitem__(self, i): - return self.positions[i] - - def __len__(self): - return self.positions.shape[0] - - def sample(self, size=None): - idx = np.random.randint(0, len(self), size=size) - return self.positions[idx] diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py new file mode 100644 index 000000000..eedfc807b --- /dev/null +++ b/gflownet/utils/molecule/datasets.py @@ -0,0 +1,51 @@ +import dgl +import numpy as np + +from gflownet.utils.common import download_file_if_not_exists +from gflownet.utils.molecule import constants +from gflownet.utils.molecule.dgl_conformer import DGLConformer + +class AtomPositionsDataset: + def __init__(self, path_to_data, url_to_data): + path_to_data = download_file_if_not_exists(path_to_data, url_to_data) + self.positions = np.load(path_to_data) + + def __getitem__(self, i): + return self.positions[i] + + def __len__(self): + return self.positions.shape[0] + + def sample(self, size=None): + idx = np.random.randint(0, len(self), size=size) + return self.positions[idx] + +class ConformersDataset: + def __init__(self, path_to_data, url_to_data): + # TODO create a new dataset if path_to_data or url_to_data doesn't exist + path_to_data = download_file_if_not_exists(path_to_data, url_to_data) + with open(path_to_data, 'rb') as inp: + self.conformers = pickle.load(inp) + + def get_conformer(self): + """ + Returns dgl graph with features stored in the dataset: + - ndata: + - atom features + - atomic numbers + - atom position + - edata: + - edge features + - rotatable bonds mask + """ + smiles = np.random.choice(self.conformers.keys()) + edges = self.conformers[smiles]['edges'] + graph = dgl.graph(edges) + graph.ndata[constants.atom_feature_name] = self.conformers[smiles][constants.atom_feature_name] + graph.ndata[constants.atomic_numbers_name] = self.conformers[smiles][constants.atomic_numbers_name] + graph.edata[constants.edge_feature_name] = self.conformers[smiles][constants.edge_feature_name] + graph.edata[constants.rotatable_bonds_mask] = self.conformers[smiles][constants.rotatable_bonds_mask] + conf_idx = np.random.randint(0, self.conformers[smiles][constants.atom_position_name].shape[0]) + graph.ndata[constants.atom_position_name] = self.conformers[smiles][constants.atom_position_name][conf_idx] + conformer = DGLConformer(graph) + return smiles, conformer \ No newline at end of file diff --git a/gflownet/utils/molecule/dgl_conformer.py b/gflownet/utils/molecule/dgl_conformer.py new file mode 100644 index 000000000..acfab419b --- /dev/null +++ b/gflownet/utils/molecule/dgl_conformer.py @@ -0,0 +1,13 @@ +import torch + +class DGLConformer: + def __init__(self, dgl_graph): + self.graph = dgl_graph + + def increment_torsion_angles(self, increments): + raise NotImplementedError + + def randomise_torsion_angles(self): + raise NotImplementedError + + \ No newline at end of file diff --git a/gflownet/utils/molecule/conformer.py b/gflownet/utils/molecule/old_conformer.py similarity index 98% rename from gflownet/utils/molecule/conformer.py rename to gflownet/utils/molecule/old_conformer.py index 689c73db5..b881bf4cf 100644 --- a/gflownet/utils/molecule/conformer.py +++ b/gflownet/utils/molecule/old_conformer.py @@ -10,9 +10,10 @@ from gflownet.utils.molecule import constants from gflownet.utils.molecule.conformer_base import ConformerBase from gflownet.utils.molecule.featurizer import MolDGLFeaturizer +from gflownet.utils.molecule.rdkit_conformer import RDKitConformer -class Conformer(ConformerBase): +class Conformer(RDKitConformer): def __init__(self, atom_positions, smiles, atom_types, freely_rotatable_tas=None): """ :param atom_positions: numpy.ndarray of shape [num_atoms, 3] of dtype float64 diff --git a/gflownet/utils/molecule/conformer_base.py b/gflownet/utils/molecule/rdkit_conformer.py similarity index 96% rename from gflownet/utils/molecule/conformer_base.py rename to gflownet/utils/molecule/rdkit_conformer.py index 7282010e4..47d689447 100644 --- a/gflownet/utils/molecule/conformer_base.py +++ b/gflownet/utils/molecule/rdkit_conformer.py @@ -31,13 +31,15 @@ def get_dummy_ad_atom_positions(): return rconf.GetPositions() -def get_dummy_ad_conf_base(): +def get_dummy_ad_rdkconf(): pos = get_dummy_ad_atom_positions() - conf = ConformerBase(pos, constants.ad_smiles, constants.ad_free_tas) + conf = RDKitConformer( + pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas + ) return conf -class ConformerBase: +class RDKitConformer: def __init__(self, atom_positions, smiles, freely_rotatable_tas=None): """ :param atom_positions: numpy.ndarray of shape [num_atoms, 3] of dtype float64 @@ -141,7 +143,7 @@ def increment_torsion_angle(self, torsion_angle, increment): test_pos = rconf.GetPositions() initial_tas = get_all_torsion_angles(rmol, rconf) - conf = ConformerBase( + conf = RDKitConformer( test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas ) # check torsion angles randomisation diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py new file mode 100644 index 000000000..e8d52e682 --- /dev/null +++ b/gflownet/utils/molecule/torsions.py @@ -0,0 +1,27 @@ +import torch +import networkx as nx +import numpy as np + +def get_rotation_masks(dgl_graph): + """ + :param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...] + """ + nx_graph = nx.DiGraph(dgl_graph.to_networkx()) + # bonds are indirected edges + bonds = torch.stack(dgl_graph.edges()).numpy().T[::2] + bonds_mask = np.zeros(bonds.shape[0], dtype=bool) + nodes_mask = np.zeros((bonds.shape[0], dgl_graph.num_nodes()), dtype=bool) + # fill in masks for bonds + for bond_idx, bond in enumerate(bonds): + modified_graph = nx_graph.to_undirected() + modified_graph.remove_edge(*bond) + if not nx.is_connected(modified_graph): + smallest_component_nodes = sorted(nx.connected_components(modified_graph), key=len)[0] + if len(smallest_component_nodes) > 1: + bonds_mask[bond_idx] = True + affected_nodes = np.array(list(smallest_component_nodes - set(bond))) + nodes_mask[bond_idx, affected_nodes] = np.ones_like(affected_nodes, dtype=bool) + # broadcast bond masks to edges masks + edges_mask = torch.from_numpy(bonds_mask.repeat(2)) + nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) + return edges_mask, nodes_mask diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py new file mode 100644 index 000000000..f1e3e65bf --- /dev/null +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -0,0 +1,33 @@ +import pytest +import torch +import dgl + +from gflownet.utils.molecule.torsions import get_rotation_masks + +def test_four_nodes_chain(): + graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) + edges_mask, nodes_mask = get_rotation_masks(graph) + correct_edges_mask = torch.tensor([False, False, True, True, False, False]) + correct_nodes_mask = torch.tensor([[False, False, False, False], + [False, False, False, False], + [ True, False, False, False], + [ True, False, False, False], + [False, False, False, False], + [False, False, False, False]]) + assert torch.all(edges_mask == correct_edges_mask) + assert torch.all(nodes_mask == correct_nodes_mask) + +def test_choose_smallest_component(): + graph = dgl.graph(([0, 2, 1, 2, 2, 3, 3, 4], [2, 0, 2, 1, 3, 2, 4, 3])) + edges_mask, nodes_mask = get_rotation_masks(graph) + correct_edges_mask = torch.tensor([False, False, False, False, True, True, False, False]) + correct_nodes_mask = torch.tensor([[False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, True], + [False, False, False, False, True], + [False, False, False, False, False], + [False, False, False, False, False]]) + assert torch.all(edges_mask == correct_edges_mask) + assert torch.all(nodes_mask == correct_nodes_mask) \ No newline at end of file From 03a445c67d41dec4e771a3c418bd3011b0c203eb Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 28 Feb 2023 18:11:57 -0500 Subject: [PATCH 008/100] add masks to featuraser --- gflownet/utils/molecule/constants.py | 2 ++ gflownet/utils/molecule/featurizer.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/gflownet/utils/molecule/constants.py b/gflownet/utils/molecule/constants.py index d2786b6cb..e74c3616c 100644 --- a/gflownet/utils/molecule/constants.py +++ b/gflownet/utils/molecule/constants.py @@ -6,6 +6,8 @@ edge_feature_name = "edge_features" step_feature_name = "step" atomic_numbers_name = "atomic_numbers" +rotatable_edges_mask_name = "rotatable_edges" +rotation_affected_nodes_mask_name = "rotation_affected_nodes" # Options for atoms featurization ad_atom_types = ("H", "C", "N", "O") diff --git a/gflownet/utils/molecule/featurizer.py b/gflownet/utils/molecule/featurizer.py index f96a82fdd..caa480a8d 100644 --- a/gflownet/utils/molecule/featurizer.py +++ b/gflownet/utils/molecule/featurizer.py @@ -2,6 +2,7 @@ import torch from gflownet.utils.molecule import constants +from gflownet.utils.molecule.torsions import get_rotation_masks class MolDGLFeaturizer: @@ -107,6 +108,9 @@ def mol2dgl(self, mol): graph.ndata[constants.atom_feature_name] = node_features graph.ndata[constants.atomic_numbers_name] = self.get_atomic_numbers(mol) graph.edata[constants.edge_feature_name] = edge_features + edges_mask, nodes_mask = get_rotation_masks(graph) + graph.edata[constants.rotatable_edges_mask_name] = edges_mask + graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask return graph @@ -123,5 +127,6 @@ def mol2dgl(self, mol): print("node features shape:", graph.ndata[constants.atom_feature_name].shape) print("edge features shape:", graph.edata[constants.edge_feature_name].shape) print("edges:", *graph.edges(), sep="\n") + print(graph.edata[constants.rotatable_edges_mask_name]) assert graph.ndata[constants.atom_feature_name].shape[0] == mol.GetNumAtoms() assert graph.edata[constants.edge_feature_name].shape[0] == 2 * mol.GetNumBonds() From b9be53c7dc58b9fb777a33c16480badc2576af9f Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 28 Feb 2023 18:55:38 -0500 Subject: [PATCH 009/100] implemented apply rotations --- gflownet/utils/molecule/dgl_conformer.py | 8 +++++- gflownet/utils/molecule/torsions.py | 33 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/molecule/dgl_conformer.py b/gflownet/utils/molecule/dgl_conformer.py index acfab419b..ce08c6511 100644 --- a/gflownet/utils/molecule/dgl_conformer.py +++ b/gflownet/utils/molecule/dgl_conformer.py @@ -4,7 +4,13 @@ class DGLConformer: def __init__(self, dgl_graph): self.graph = dgl_graph - def increment_torsion_angles(self, increments): + def apply_rotations(self, rotations): + """ + Apply rotations (torsion angles updates) + :param rotations: a sequence of torsion angle updates of length = number of bonds in the molecule. + The order corresponds to the order of edges in self.graph, such that action[i] is + an update for the torsion angle corresponding to the edge[2i] + """ raise NotImplementedError def randomise_torsion_angles(self): diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index e8d52e682..2f1927405 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -2,6 +2,10 @@ import networkx as nx import numpy as np +from pytorch3d.transforms import axis_angle_to_matrix + +from gflownet.utils.molecule import constants + def get_rotation_masks(dgl_graph): """ :param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...] @@ -25,3 +29,32 @@ def get_rotation_masks(dgl_graph): edges_mask = torch.from_numpy(bonds_mask.repeat(2)) nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) return edges_mask, nodes_mask + +def apply_rotations(dgl_graph, rotations): + """ + Apply rotations (torsion angles updates) + :param dgl_graph: bidirectional dgl.Graph + :param rotations: a sequence of torsion angle updates of length = number of bonds in the molecule. + The order corresponds to the order of edges in the graph, such that action[i] is + an update for the torsion angle corresponding to the edge[2i] + """ + pos = graph.ndata[constants.atom_position_name] + edge_mask = graph.edata[constants.rotatable_edges_mask_name] + node_mask = graph.edata[constants.rotation_affected_nodes_mask_name] + edges = torch.stack(graph.edges()).T + # TODO check how slow it is and whether it's possible to vectorise this loop + for idx_update, update in enumerate(rotations): + idx_edge = idx_update * 2 + if edge_mask[idx_edge]: + begin_pos = pos[edges[idx_edge][0]] + end_pos = pos[edges[idx_edge][1]] + rot_vector = end_pos - begin_pos + rot_vector = rot_vector / torch.linalg.norm(rot_vector) * update + rot_matrix = axis_angle_to_matrix(rot_vector) + x = pos[node_mask[idx_edge]] + pos[node_mask[idx_edge]] = torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos + dgl_graph.ndata[constants.atom_position_name] = pos + return dgl_graph + + + From c80b2f8062766cacd76ee1e53a4b6550a9ed26c6 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Wed, 1 Mar 2023 19:20:56 -0500 Subject: [PATCH 010/100] add simple tests for apply rotations --- gflownet/utils/molecule/torsions.py | 36 ++++- .../gflownet/utils/molecule/test_torsions.py | 140 +++++++++++++++++- 2 files changed, 171 insertions(+), 5 deletions(-) diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index 2f1927405..06551cd91 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -30,7 +30,7 @@ def get_rotation_masks(dgl_graph): nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) return edges_mask, nodes_mask -def apply_rotations(dgl_graph, rotations): +def apply_rotations(graph, rotations): """ Apply rotations (torsion angles updates) :param dgl_graph: bidirectional dgl.Graph @@ -53,8 +53,38 @@ def apply_rotations(dgl_graph, rotations): rot_matrix = axis_angle_to_matrix(rot_vector) x = pos[node_mask[idx_edge]] pos[node_mask[idx_edge]] = torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos - dgl_graph.ndata[constants.atom_position_name] = pos - return dgl_graph + graph.ndata[constants.atom_position_name] = pos + return graph +if __name__ == '__main__': + from rdkit import Chem + from rdkit.Chem import AllChem + from rdkit.Chem import rdMolTransforms + from rdkit.Chem import TorsionFingerprints + from rdkit.Geometry.rdGeometry import Point3D + from gflownet.utils.molecule.featurizer import MolDGLFeaturizer + from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values + + mol = Chem.MolFromSmiles(constants.ad_smiles) + mol = Chem.AddHs(mol) + AllChem.EmbedMolecule(mol) + rconf = mol.GetConformer() + start_pos = rconf.GetPositions() + + featurizer = MolDGLFeaturizer(constants.ad_atom_types) + graph = featurizer.mol2dgl(mol) + graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) + bonds = torch.stack(graph.edges())[:,::2] + print(bonds) + print(graph.edata[constants.rotatable_edges_mask_name][::2]) + print(bonds[:, graph.edata[constants.rotatable_edges_mask_name][::2]]) + torsion_angles = [(10, 0, 1, 6)] + print(get_torsion_angles_values(rconf, torsion_angles)) + torsion_angles = [(11, 0, 1, 6)] + print(get_torsion_angles_values(rconf, torsion_angles)) + torsion_angles = [(6, 1, 0, 10)] + print(get_torsion_angles_values(rconf, torsion_angles)) + torsion_angles = [(6, 0, 1, 10)] + print(get_torsion_angles_values(rconf, torsion_angles)) diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index f1e3e65bf..e8a1795ce 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -2,7 +2,8 @@ import torch import dgl -from gflownet.utils.molecule.torsions import get_rotation_masks +from gflownet.utils.molecule.torsions import get_rotation_masks, apply_rotations +from gflownet.utils.molecule import constants def test_four_nodes_chain(): graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) @@ -30,4 +31,139 @@ def test_choose_smallest_component(): [False, False, False, False, False], [False, False, False, False, False]]) assert torch.all(edges_mask == correct_edges_mask) - assert torch.all(nodes_mask == correct_nodes_mask) \ No newline at end of file + assert torch.all(nodes_mask == correct_nodes_mask) + +@pytest.mark.parametrize( + "angle, exp_result", + [ + ( + torch.pi / 2, + torch.tensor( + [[1., 0., 1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi, + torch.tensor( + [[2., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 3 / 2, + torch.tensor( + [[1., 0., -1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 2, + torch.tensor( + [[0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + + ] + +) +def test_apply_rotations_simple(angle, exp_result): + graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) + graph.ndata[constants.atom_position_name] = torch.tensor([ + [0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.] + ]) + edges_mask, nodes_mask = get_rotation_masks(graph) + graph.edata[constants.rotatable_edges_mask_name] = edges_mask + graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + rotations = torch.tensor([0., angle, 0.]) + result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] + assert torch.allclose(result, exp_result, atol=1e-6) + + +@pytest.mark.parametrize( + "angle, exp_result", + [ + ( + torch.pi / 2, + torch.tensor( + [[1., 0., 1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi, + torch.tensor( + [[2., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 3 / 2, + torch.tensor( + [[1., 0., -1.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + ( + torch.pi * 2, + torch.tensor( + [[0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.]]) + ), + + ] + +) +def test_apply_rotations_ignore_nonrotatable(angle, exp_result): + graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) + graph.ndata[constants.atom_position_name] = torch.tensor([ + [0., 0., 0.], + [1., 0., 0.], + [1., 1., 0.], + [2., 1., 0.] + ]) + edges_mask, nodes_mask = get_rotation_masks(graph) + graph.edata[constants.rotatable_edges_mask_name] = edges_mask + graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + rotations = torch.tensor([2., angle, -1.]) + result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] + assert torch.allclose(result, exp_result, atol=1e-6) + +# def test_apply_rotation_alanine_dipeptide(): +# from rdkit import Chem + +# mol = Chem.MolFromSmiles(constants.ad_smiles) +# mol = Chem.AddHs(mol) +# AllChem.EmbedMolecule(rmol) +# rconf = rmol.GetConformer() +# start_pos = rconf.GetPositions() + +# featurizer = MolDGLFeaturizer(constants.ad_atom_types) + +# graph = featurizer.mol2dgl(mol) +# graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) + + +# rmol = Chem.MolFromSmiles(constants.ad_smiles) +# rmol = Chem.AddHs(rmol) +# AllChem.EmbedMolecule(rmol) +# rconf = rmol.GetConformer() +# test_pos = rconf.GetPositions() +# initial_tas = get_all_torsion_angles(rmol, rconf) + +# conf = RDKitConformer( +# test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas +# ) \ No newline at end of file From ea8e410dbf7bccc1bdc17306ba29d3a0be8e053e Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 2 Mar 2023 17:24:59 -0500 Subject: [PATCH 011/100] fix bug in torsios, add test with AD --- gflownet/utils/molecule/constants.py | 1 + gflownet/utils/molecule/featurizer.py | 3 +- gflownet/utils/molecule/torsions.py | 45 ++------- .../gflownet/utils/molecule/test_torsions.py | 92 ++++++++++++------- 4 files changed, 71 insertions(+), 70 deletions(-) diff --git a/gflownet/utils/molecule/constants.py b/gflownet/utils/molecule/constants.py index e74c3616c..ce963c3a1 100644 --- a/gflownet/utils/molecule/constants.py +++ b/gflownet/utils/molecule/constants.py @@ -8,6 +8,7 @@ atomic_numbers_name = "atomic_numbers" rotatable_edges_mask_name = "rotatable_edges" rotation_affected_nodes_mask_name = "rotation_affected_nodes" +rotation_signs_name = "rotation_signs" # Options for atoms featurization ad_atom_types = ("H", "C", "N", "O") diff --git a/gflownet/utils/molecule/featurizer.py b/gflownet/utils/molecule/featurizer.py index caa480a8d..51be811db 100644 --- a/gflownet/utils/molecule/featurizer.py +++ b/gflownet/utils/molecule/featurizer.py @@ -108,9 +108,10 @@ def mol2dgl(self, mol): graph.ndata[constants.atom_feature_name] = node_features graph.ndata[constants.atomic_numbers_name] = self.get_atomic_numbers(mol) graph.edata[constants.edge_feature_name] = edge_features - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + graph.edata[constants.rotation_signs_name] = rotation_signs return graph diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index 06551cd91..b210c0e2e 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -15,6 +15,7 @@ def get_rotation_masks(dgl_graph): bonds = torch.stack(dgl_graph.edges()).numpy().T[::2] bonds_mask = np.zeros(bonds.shape[0], dtype=bool) nodes_mask = np.zeros((bonds.shape[0], dgl_graph.num_nodes()), dtype=bool) + rotation_signs = np.zeros(bonds.shape[0], dtype=float) # fill in masks for bonds for bond_idx, bond in enumerate(bonds): modified_graph = nx_graph.to_undirected() @@ -23,12 +24,15 @@ def get_rotation_masks(dgl_graph): smallest_component_nodes = sorted(nx.connected_components(modified_graph), key=len)[0] if len(smallest_component_nodes) > 1: bonds_mask[bond_idx] = True + rotation_signs[bond_idx] = -1 if bond[0] in smallest_component_nodes else 1 affected_nodes = np.array(list(smallest_component_nodes - set(bond))) nodes_mask[bond_idx, affected_nodes] = np.ones_like(affected_nodes, dtype=bool) + # broadcast bond masks to edges masks edges_mask = torch.from_numpy(bonds_mask.repeat(2)) + rotation_signs = torch.from_numpy(rotation_signs.repeat(2)) nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) - return edges_mask, nodes_mask + return edges_mask, nodes_mask, rotation_signs def apply_rotations(graph, rotations): """ @@ -41,50 +45,19 @@ def apply_rotations(graph, rotations): pos = graph.ndata[constants.atom_position_name] edge_mask = graph.edata[constants.rotatable_edges_mask_name] node_mask = graph.edata[constants.rotation_affected_nodes_mask_name] + rot_signs = graph.edata[constants.rotation_signs_name] edges = torch.stack(graph.edges()).T # TODO check how slow it is and whether it's possible to vectorise this loop for idx_update, update in enumerate(rotations): + # import ipdb; ipdb.set_trace() idx_edge = idx_update * 2 if edge_mask[idx_edge]: begin_pos = pos[edges[idx_edge][0]] end_pos = pos[edges[idx_edge][1]] rot_vector = end_pos - begin_pos - rot_vector = rot_vector / torch.linalg.norm(rot_vector) * update + rot_vector = rot_vector / torch.linalg.norm(rot_vector) * update * rot_signs[idx_edge] rot_matrix = axis_angle_to_matrix(rot_vector) x = pos[node_mask[idx_edge]] pos[node_mask[idx_edge]] = torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos graph.ndata[constants.atom_position_name] = pos - return graph - - - -if __name__ == '__main__': - from rdkit import Chem - from rdkit.Chem import AllChem - from rdkit.Chem import rdMolTransforms - from rdkit.Chem import TorsionFingerprints - from rdkit.Geometry.rdGeometry import Point3D - from gflownet.utils.molecule.featurizer import MolDGLFeaturizer - from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values - - mol = Chem.MolFromSmiles(constants.ad_smiles) - mol = Chem.AddHs(mol) - AllChem.EmbedMolecule(mol) - rconf = mol.GetConformer() - start_pos = rconf.GetPositions() - - featurizer = MolDGLFeaturizer(constants.ad_atom_types) - graph = featurizer.mol2dgl(mol) - graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) - bonds = torch.stack(graph.edges())[:,::2] - print(bonds) - print(graph.edata[constants.rotatable_edges_mask_name][::2]) - print(bonds[:, graph.edata[constants.rotatable_edges_mask_name][::2]]) - torsion_angles = [(10, 0, 1, 6)] - print(get_torsion_angles_values(rconf, torsion_angles)) - torsion_angles = [(11, 0, 1, 6)] - print(get_torsion_angles_values(rconf, torsion_angles)) - torsion_angles = [(6, 1, 0, 10)] - print(get_torsion_angles_values(rconf, torsion_angles)) - torsion_angles = [(6, 0, 1, 10)] - print(get_torsion_angles_values(rconf, torsion_angles)) + return graph \ No newline at end of file diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index e8a1795ce..b5ba74997 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -2,12 +2,18 @@ import torch import dgl +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Geometry.rdGeometry import Point3D + from gflownet.utils.molecule.torsions import get_rotation_masks, apply_rotations from gflownet.utils.molecule import constants +from gflownet.utils.molecule.featurizer import MolDGLFeaturizer +from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values def test_four_nodes_chain(): graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) correct_edges_mask = torch.tensor([False, False, True, True, False, False]) correct_nodes_mask = torch.tensor([[False, False, False, False], [False, False, False, False], @@ -15,12 +21,14 @@ def test_four_nodes_chain(): [ True, False, False, False], [False, False, False, False], [False, False, False, False]]) + correct_rotation_signs = torch.tensor([ 0., 0., -1., -1., 0., 0.]) assert torch.all(edges_mask == correct_edges_mask) assert torch.all(nodes_mask == correct_nodes_mask) + assert torch.all(rotation_signs == correct_rotation_signs) def test_choose_smallest_component(): graph = dgl.graph(([0, 2, 1, 2, 2, 3, 3, 4], [2, 0, 2, 1, 3, 2, 4, 3])) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) correct_edges_mask = torch.tensor([False, False, False, False, True, True, False, False]) correct_nodes_mask = torch.tensor([[False, False, False, False, False], [False, False, False, False, False], @@ -30,8 +38,10 @@ def test_choose_smallest_component(): [False, False, False, False, True], [False, False, False, False, False], [False, False, False, False, False]]) + correct_rotation_signs = torch.tensor([0., 0., 0., 0., 1., 1., 0., 0.]) assert torch.all(edges_mask == correct_edges_mask) assert torch.all(nodes_mask == correct_nodes_mask) + assert torch.all(rotation_signs == correct_rotation_signs) @pytest.mark.parametrize( "angle, exp_result", @@ -39,7 +49,7 @@ def test_choose_smallest_component(): ( torch.pi / 2, torch.tensor( - [[1., 0., 1.], + [[1., 0., -1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -55,7 +65,7 @@ def test_choose_smallest_component(): ( torch.pi * 3 / 2, torch.tensor( - [[1., 0., -1.], + [[1., 0., 1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -80,9 +90,10 @@ def test_apply_rotations_simple(angle, exp_result): [1., 1., 0.], [2., 1., 0.] ]) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + graph.edata[constants.rotation_signs_name] = rotation_signs rotations = torch.tensor([0., angle, 0.]) result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] assert torch.allclose(result, exp_result, atol=1e-6) @@ -94,7 +105,7 @@ def test_apply_rotations_simple(angle, exp_result): ( torch.pi / 2, torch.tensor( - [[1., 0., 1.], + [[1., 0., -1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -110,7 +121,7 @@ def test_apply_rotations_simple(angle, exp_result): ( torch.pi * 3 / 2, torch.tensor( - [[1., 0., -1.], + [[1., 0., 1.], [1., 0., 0.], [1., 1., 0.], [2., 1., 0.]]) @@ -135,35 +146,50 @@ def test_apply_rotations_ignore_nonrotatable(angle, exp_result): [1., 1., 0.], [2., 1., 0.] ]) - edges_mask, nodes_mask = get_rotation_masks(graph) + edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask + graph.edata[constants.rotation_signs_name] = rotation_signs rotations = torch.tensor([2., angle, -1.]) result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] assert torch.allclose(result, exp_result, atol=1e-6) -# def test_apply_rotation_alanine_dipeptide(): -# from rdkit import Chem - -# mol = Chem.MolFromSmiles(constants.ad_smiles) -# mol = Chem.AddHs(mol) -# AllChem.EmbedMolecule(rmol) -# rconf = rmol.GetConformer() -# start_pos = rconf.GetPositions() - -# featurizer = MolDGLFeaturizer(constants.ad_atom_types) - -# graph = featurizer.mol2dgl(mol) -# graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) - - -# rmol = Chem.MolFromSmiles(constants.ad_smiles) -# rmol = Chem.AddHs(rmol) -# AllChem.EmbedMolecule(rmol) -# rconf = rmol.GetConformer() -# test_pos = rconf.GetPositions() -# initial_tas = get_all_torsion_angles(rmol, rconf) - -# conf = RDKitConformer( -# test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas -# ) \ No newline at end of file +def stress_test_apply_rotation_alanine_dipeptide(): + from rdkit import Chem + from rdkit.Chem import AllChem + from rdkit.Geometry.rdGeometry import Point3D + from gflownet.utils.molecule.featurizer import MolDGLFeaturizer + from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values + + mol = Chem.MolFromSmiles(constants.ad_smiles) + mol = Chem.AddHs(mol) + AllChem.EmbedMolecule(mol) + rconf = mol.GetConformer() + start_pos = rconf.GetPositions() + + featurizer = MolDGLFeaturizer(constants.ad_atom_types) + graph = featurizer.mol2dgl(mol) + graph.ndata[constants.atom_position_name] = torch.from_numpy(start_pos) + + torsion_angles = [ + (10, 0, 1, 6), + (0, 1, 2, 3), + (1, 2, 4, 14), + (2, 4, 5, 15), + (0, 1, 6, 7), + (18, 6, 7, 8), + (8, 7, 9, 19) + ] + n_edges = graph.edges()[0].shape[-1] + for _ in range (100): + ta_initial_values = torch.tensor(get_torsion_angles_values(rconf, torsion_angles)) + + rotations = torch.rand(n_edges // 2) * torch.pi * 2 + graph = apply_rotations(graph, rotations) + new_pos = graph.ndata[constants.atom_position_name].numpy() + for idx, pos in enumerate(new_pos): + rconf.SetAtomPosition(idx, Point3D(*pos)) + ta_updated_values = torch.tensor(get_torsion_angles_values(rconf, torsion_angles)) + valid_rotations = rotations[graph.edata[constants.rotatable_edges_mask_name][::2]] + diff = (ta_updated_values - ta_initial_values - valid_rotations) % (2*torch.pi) + assert torch.logical_or(torch.isclose(diff, torch.zeros_like(diff), atol=1e-6), torch.isclose(diff, torch.ones_like(diff)*2*torch.pi, atol=1e-5)).all() \ No newline at end of file From 7cf2f3eb1eacd0f10a138060ace5bbd86a71a94a Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 11 Apr 2023 18:07:43 -0400 Subject: [PATCH 012/100] add a comment to ConformerDataset --- gflownet/utils/molecule/datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py index eedfc807b..cfb43dba4 100644 --- a/gflownet/utils/molecule/datasets.py +++ b/gflownet/utils/molecule/datasets.py @@ -38,6 +38,7 @@ def get_conformer(self): - edge features - rotatable bonds mask """ + # TODO make it work if there're several conformers for a single molecule smiles = np.random.choice(self.conformers.keys()) edges = self.conformers[smiles]['edges'] graph = dgl.graph(edges) @@ -48,4 +49,4 @@ def get_conformer(self): conf_idx = np.random.randint(0, self.conformers[smiles][constants.atom_position_name].shape[0]) graph.ndata[constants.atom_position_name] = self.conformers[smiles][constants.atom_position_name][conf_idx] conformer = DGLConformer(graph) - return smiles, conformer \ No newline at end of file + return smiles, conformer From 38afb292be121d58b4a17fa36ca30bdd90f9f8c8 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 26 Apr 2023 10:27:42 -0400 Subject: [PATCH 013/100] updated setup --- setup_conformer.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup_conformer.sh b/setup_conformer.sh index 0ef7c720b..34ed6e3fc 100644 --- a/setup_conformer.sh +++ b/setup_conformer.sh @@ -17,8 +17,8 @@ python -m pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spl # Install DGL (see https://www.dgl.ai/pages/start.html) - giving problems python -m pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html # Requirements to run -python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani +python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d # Conditional requirements -python -m pip install wandb matplotlib plotly gdown +python -m pip install wandb matplotlib plotly pymatgen gdown # Dev packages # python -m pip install black flake8 isort pylint ipdb jupyter pytest pytest-repeat From f0d8225906f32813863b806ad7c161c09b29cd5b Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 26 Apr 2023 10:31:13 -0400 Subject: [PATCH 014/100] fixing tests --- gflownet/utils/molecule/rdkit_conformer.py | 2 +- tests/gflownet/proxy/test_molecule.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gflownet/utils/molecule/rdkit_conformer.py b/gflownet/utils/molecule/rdkit_conformer.py index 47d689447..140e8b1b3 100644 --- a/gflownet/utils/molecule/rdkit_conformer.py +++ b/gflownet/utils/molecule/rdkit_conformer.py @@ -34,7 +34,7 @@ def get_dummy_ad_atom_positions(): def get_dummy_ad_rdkconf(): pos = get_dummy_ad_atom_positions() conf = RDKitConformer( - pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas + pos, constants.ad_smiles, constants.ad_free_tas ) return conf diff --git a/tests/gflownet/proxy/test_molecule.py b/tests/gflownet/proxy/test_molecule.py index 02c27e76f..07f7a70c3 100644 --- a/tests/gflownet/proxy/test_molecule.py +++ b/tests/gflownet/proxy/test_molecule.py @@ -3,7 +3,7 @@ import torch from gflownet.proxy.molecule import TorchANIMoleculeEnergy -from gflownet.utils.molecule.conformer_base import get_dummy_ad_conf_base +from gflownet.utils.molecule.rdkit_conformer import get_dummy_ad_rdkconf @pytest.fixture() @@ -12,14 +12,14 @@ def proxy(): def test__torchani_molecule_energy__predicts_energy_for_a_single_numpy_conformer(proxy): - conf = get_dummy_ad_conf_base() + conf = get_dummy_ad_rdkconf() coordinates, elements = conf.get_atom_positions(), conf.get_atomic_numbers() proxy(elements[np.newaxis, ...], coordinates[np.newaxis, ...]) def test__torchani_molecule_energy__predicts_energy_for_a_pytorch_batch(proxy): - conf = get_dummy_ad_conf_base() + conf = get_dummy_ad_rdkconf() coordinates, elements = conf.get_atom_positions(), conf.get_atomic_numbers() coordinates = torch.Tensor(coordinates).repeat(3, 1, 1) From 8b7502ee2c4b666844f3e1a38b22f5a22e490d88 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 26 Apr 2023 10:31:46 -0400 Subject: [PATCH 015/100] black --- gflownet/utils/molecule/datasets.py | 34 ++-- gflownet/utils/molecule/dgl_conformer.py | 3 +- gflownet/utils/molecule/rdkit_conformer.py | 4 +- gflownet/utils/molecule/torsions.py | 33 +++- .../gflownet/utils/molecule/test_torsions.py | 165 +++++++++--------- 5 files changed, 132 insertions(+), 107 deletions(-) diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py index cfb43dba4..f87fe7b22 100644 --- a/gflownet/utils/molecule/datasets.py +++ b/gflownet/utils/molecule/datasets.py @@ -5,6 +5,7 @@ from gflownet.utils.molecule import constants from gflownet.utils.molecule.dgl_conformer import DGLConformer + class AtomPositionsDataset: def __init__(self, path_to_data, url_to_data): path_to_data = download_file_if_not_exists(path_to_data, url_to_data) @@ -20,17 +21,18 @@ def sample(self, size=None): idx = np.random.randint(0, len(self), size=size) return self.positions[idx] + class ConformersDataset: def __init__(self, path_to_data, url_to_data): # TODO create a new dataset if path_to_data or url_to_data doesn't exist path_to_data = download_file_if_not_exists(path_to_data, url_to_data) - with open(path_to_data, 'rb') as inp: + with open(path_to_data, "rb") as inp: self.conformers = pickle.load(inp) - + def get_conformer(self): """ Returns dgl graph with features stored in the dataset: - - ndata: + - ndata: - atom features - atomic numbers - atom position @@ -40,13 +42,25 @@ def get_conformer(self): """ # TODO make it work if there're several conformers for a single molecule smiles = np.random.choice(self.conformers.keys()) - edges = self.conformers[smiles]['edges'] + edges = self.conformers[smiles]["edges"] graph = dgl.graph(edges) - graph.ndata[constants.atom_feature_name] = self.conformers[smiles][constants.atom_feature_name] - graph.ndata[constants.atomic_numbers_name] = self.conformers[smiles][constants.atomic_numbers_name] - graph.edata[constants.edge_feature_name] = self.conformers[smiles][constants.edge_feature_name] - graph.edata[constants.rotatable_bonds_mask] = self.conformers[smiles][constants.rotatable_bonds_mask] - conf_idx = np.random.randint(0, self.conformers[smiles][constants.atom_position_name].shape[0]) - graph.ndata[constants.atom_position_name] = self.conformers[smiles][constants.atom_position_name][conf_idx] + graph.ndata[constants.atom_feature_name] = self.conformers[smiles][ + constants.atom_feature_name + ] + graph.ndata[constants.atomic_numbers_name] = self.conformers[smiles][ + constants.atomic_numbers_name + ] + graph.edata[constants.edge_feature_name] = self.conformers[smiles][ + constants.edge_feature_name + ] + graph.edata[constants.rotatable_bonds_mask] = self.conformers[smiles][ + constants.rotatable_bonds_mask + ] + conf_idx = np.random.randint( + 0, self.conformers[smiles][constants.atom_position_name].shape[0] + ) + graph.ndata[constants.atom_position_name] = self.conformers[smiles][ + constants.atom_position_name + ][conf_idx] conformer = DGLConformer(graph) return smiles, conformer diff --git a/gflownet/utils/molecule/dgl_conformer.py b/gflownet/utils/molecule/dgl_conformer.py index ce08c6511..990ab9d46 100644 --- a/gflownet/utils/molecule/dgl_conformer.py +++ b/gflownet/utils/molecule/dgl_conformer.py @@ -1,5 +1,6 @@ import torch + class DGLConformer: def __init__(self, dgl_graph): self.graph = dgl_graph @@ -15,5 +16,3 @@ def apply_rotations(self, rotations): def randomise_torsion_angles(self): raise NotImplementedError - - \ No newline at end of file diff --git a/gflownet/utils/molecule/rdkit_conformer.py b/gflownet/utils/molecule/rdkit_conformer.py index 140e8b1b3..330f493ba 100644 --- a/gflownet/utils/molecule/rdkit_conformer.py +++ b/gflownet/utils/molecule/rdkit_conformer.py @@ -33,9 +33,7 @@ def get_dummy_ad_atom_positions(): def get_dummy_ad_rdkconf(): pos = get_dummy_ad_atom_positions() - conf = RDKitConformer( - pos, constants.ad_smiles, constants.ad_free_tas - ) + conf = RDKitConformer(pos, constants.ad_smiles, constants.ad_free_tas) return conf diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index b210c0e2e..a84146933 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -6,9 +6,10 @@ from gflownet.utils.molecule import constants + def get_rotation_masks(dgl_graph): """ - :param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...] + :param dgl_graph: the dgl.Graph object with bidirected edges in the order: [e_1_fwd, e_1_bkw, e_2_fwd, e_2_bkw, ...] """ nx_graph = nx.DiGraph(dgl_graph.to_networkx()) # bonds are indirected edges @@ -16,24 +17,31 @@ def get_rotation_masks(dgl_graph): bonds_mask = np.zeros(bonds.shape[0], dtype=bool) nodes_mask = np.zeros((bonds.shape[0], dgl_graph.num_nodes()), dtype=bool) rotation_signs = np.zeros(bonds.shape[0], dtype=float) - # fill in masks for bonds + # fill in masks for bonds for bond_idx, bond in enumerate(bonds): modified_graph = nx_graph.to_undirected() modified_graph.remove_edge(*bond) if not nx.is_connected(modified_graph): - smallest_component_nodes = sorted(nx.connected_components(modified_graph), key=len)[0] + smallest_component_nodes = sorted( + nx.connected_components(modified_graph), key=len + )[0] if len(smallest_component_nodes) > 1: bonds_mask[bond_idx] = True - rotation_signs[bond_idx] = -1 if bond[0] in smallest_component_nodes else 1 + rotation_signs[bond_idx] = ( + -1 if bond[0] in smallest_component_nodes else 1 + ) affected_nodes = np.array(list(smallest_component_nodes - set(bond))) - nodes_mask[bond_idx, affected_nodes] = np.ones_like(affected_nodes, dtype=bool) - + nodes_mask[bond_idx, affected_nodes] = np.ones_like( + affected_nodes, dtype=bool + ) + # broadcast bond masks to edges masks edges_mask = torch.from_numpy(bonds_mask.repeat(2)) rotation_signs = torch.from_numpy(rotation_signs.repeat(2)) nodes_mask = torch.from_numpy(nodes_mask.repeat(2, axis=0)) return edges_mask, nodes_mask, rotation_signs + def apply_rotations(graph, rotations): """ Apply rotations (torsion angles updates) @@ -55,9 +63,16 @@ def apply_rotations(graph, rotations): begin_pos = pos[edges[idx_edge][0]] end_pos = pos[edges[idx_edge][1]] rot_vector = end_pos - begin_pos - rot_vector = rot_vector / torch.linalg.norm(rot_vector) * update * rot_signs[idx_edge] + rot_vector = ( + rot_vector + / torch.linalg.norm(rot_vector) + * update + * rot_signs[idx_edge] + ) rot_matrix = axis_angle_to_matrix(rot_vector) x = pos[node_mask[idx_edge]] - pos[node_mask[idx_edge]] = torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos + pos[node_mask[idx_edge]] = ( + torch.matmul((x - begin_pos), rot_matrix.T) + begin_pos + ) graph.ndata[constants.atom_position_name] = pos - return graph \ No newline at end of file + return graph diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index b5ba74997..dba8f7ea5 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -11,92 +11,92 @@ from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values + def test_four_nodes_chain(): - graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) + graph = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2])) edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) - correct_edges_mask = torch.tensor([False, False, True, True, False, False]) - correct_nodes_mask = torch.tensor([[False, False, False, False], - [False, False, False, False], - [ True, False, False, False], - [ True, False, False, False], - [False, False, False, False], - [False, False, False, False]]) - correct_rotation_signs = torch.tensor([ 0., 0., -1., -1., 0., 0.]) + correct_edges_mask = torch.tensor([False, False, True, True, False, False]) + correct_nodes_mask = torch.tensor( + [ + [False, False, False, False], + [False, False, False, False], + [True, False, False, False], + [True, False, False, False], + [False, False, False, False], + [False, False, False, False], + ] + ) + correct_rotation_signs = torch.tensor([0.0, 0.0, -1.0, -1.0, 0.0, 0.0]) assert torch.all(edges_mask == correct_edges_mask) assert torch.all(nodes_mask == correct_nodes_mask) assert torch.all(rotation_signs == correct_rotation_signs) + def test_choose_smallest_component(): graph = dgl.graph(([0, 2, 1, 2, 2, 3, 3, 4], [2, 0, 2, 1, 3, 2, 4, 3])) edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) - correct_edges_mask = torch.tensor([False, False, False, False, True, True, False, False]) - correct_nodes_mask = torch.tensor([[False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, False, True], - [False, False, False, False, True], - [False, False, False, False, False], - [False, False, False, False, False]]) - correct_rotation_signs = torch.tensor([0., 0., 0., 0., 1., 1., 0., 0.]) + correct_edges_mask = torch.tensor( + [False, False, False, False, True, True, False, False] + ) + correct_nodes_mask = torch.tensor( + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, True], + [False, False, False, False, True], + [False, False, False, False, False], + [False, False, False, False, False], + ] + ) + correct_rotation_signs = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0]) assert torch.all(edges_mask == correct_edges_mask) assert torch.all(nodes_mask == correct_nodes_mask) assert torch.all(rotation_signs == correct_rotation_signs) + @pytest.mark.parametrize( "angle, exp_result", [ ( torch.pi / 2, torch.tensor( - [[1., 0., -1.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[1.0, 0.0, -1.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), ( torch.pi, torch.tensor( - [[2., 0., 0.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[2.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), ( torch.pi * 3 / 2, torch.tensor( - [[1., 0., 1.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[1.0, 0.0, 1.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), ( torch.pi * 2, torch.tensor( - [[0., 0., 0.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), - - ] - + ], ) def test_apply_rotations_simple(angle, exp_result): - graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) - graph.ndata[constants.atom_position_name] = torch.tensor([ - [0., 0., 0.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.] - ]) + graph = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2])) + graph.ndata[constants.atom_position_name] = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ) edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask graph.edata[constants.rotation_signs_name] = rotation_signs - rotations = torch.tensor([0., angle, 0.]) + rotations = torch.tensor([0.0, angle, 0.0]) result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] - assert torch.allclose(result, exp_result, atol=1e-6) + assert torch.allclose(result, exp_result, atol=1e-6) @pytest.mark.parametrize( @@ -105,54 +105,42 @@ def test_apply_rotations_simple(angle, exp_result): ( torch.pi / 2, torch.tensor( - [[1., 0., -1.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[1.0, 0.0, -1.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), ( torch.pi, torch.tensor( - [[2., 0., 0.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[2.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), ( torch.pi * 3 / 2, torch.tensor( - [[1., 0., 1.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[1.0, 0.0, 1.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), ( torch.pi * 2, torch.tensor( - [[0., 0., 0.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.]]) + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ), ), - - ] - + ], ) def test_apply_rotations_ignore_nonrotatable(angle, exp_result): - graph = dgl.graph(([0,1,1,2,2,3], [1,0,2,1,3,2])) - graph.ndata[constants.atom_position_name] = torch.tensor([ - [0., 0., 0.], - [1., 0., 0.], - [1., 1., 0.], - [2., 1., 0.] - ]) + graph = dgl.graph(([0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2])) + graph.ndata[constants.atom_position_name] = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]] + ) edges_mask, nodes_mask, rotation_signs = get_rotation_masks(graph) graph.edata[constants.rotatable_edges_mask_name] = edges_mask graph.edata[constants.rotation_affected_nodes_mask_name] = nodes_mask graph.edata[constants.rotation_signs_name] = rotation_signs - rotations = torch.tensor([2., angle, -1.]) + rotations = torch.tensor([2.0, angle, -1.0]) result = apply_rotations(graph, rotations).ndata[constants.atom_position_name] - assert torch.allclose(result, exp_result, atol=1e-6) + assert torch.allclose(result, exp_result, atol=1e-6) + def stress_test_apply_rotation_alanine_dipeptide(): from rdkit import Chem @@ -178,18 +166,29 @@ def stress_test_apply_rotation_alanine_dipeptide(): (2, 4, 5, 15), (0, 1, 6, 7), (18, 6, 7, 8), - (8, 7, 9, 19) + (8, 7, 9, 19), ] n_edges = graph.edges()[0].shape[-1] - for _ in range (100): - ta_initial_values = torch.tensor(get_torsion_angles_values(rconf, torsion_angles)) - + for _ in range(100): + ta_initial_values = torch.tensor( + get_torsion_angles_values(rconf, torsion_angles) + ) + rotations = torch.rand(n_edges // 2) * torch.pi * 2 graph = apply_rotations(graph, rotations) new_pos = graph.ndata[constants.atom_position_name].numpy() for idx, pos in enumerate(new_pos): rconf.SetAtomPosition(idx, Point3D(*pos)) - ta_updated_values = torch.tensor(get_torsion_angles_values(rconf, torsion_angles)) - valid_rotations = rotations[graph.edata[constants.rotatable_edges_mask_name][::2]] - diff = (ta_updated_values - ta_initial_values - valid_rotations) % (2*torch.pi) - assert torch.logical_or(torch.isclose(diff, torch.zeros_like(diff), atol=1e-6), torch.isclose(diff, torch.ones_like(diff)*2*torch.pi, atol=1e-5)).all() \ No newline at end of file + ta_updated_values = torch.tensor( + get_torsion_angles_values(rconf, torsion_angles) + ) + valid_rotations = rotations[ + graph.edata[constants.rotatable_edges_mask_name][::2] + ] + diff = (ta_updated_values - ta_initial_values - valid_rotations) % ( + 2 * torch.pi + ) + assert torch.logical_or( + torch.isclose(diff, torch.zeros_like(diff), atol=1e-6), + torch.isclose(diff, torch.ones_like(diff) * 2 * torch.pi, atol=1e-5), + ).all() From f0d53abbf219c6db73008f829bc2489a7d493174 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 16 May 2023 17:32:35 -0400 Subject: [PATCH 016/100] WiP molecule TorchANI proxy --- .../experiments/icml23/alaninedipeptide.yaml | 2 +- config/proxy/molecule.yaml | 4 -- gflownet/envs/__init__.py | 2 +- gflownet/proxy/molecule.py | 69 ++++++++++++------- 4 files changed, 48 insertions(+), 29 deletions(-) delete mode 100644 config/proxy/molecule.yaml diff --git a/config/experiments/icml23/alaninedipeptide.yaml b/config/experiments/icml23/alaninedipeptide.yaml index 6725d7b42..597fb33e6 100644 --- a/config/experiments/icml23/alaninedipeptide.yaml +++ b/config/experiments/icml23/alaninedipeptide.yaml @@ -3,7 +3,7 @@ defaults: - override /env: alaninedipeptide - override /gflownet: trajectorybalance - - override /proxy: molecule + - override /proxy: molecule_rf - override /logger: wandb - override /user: sasha diff --git a/config/proxy/molecule.yaml b/config/proxy/molecule.yaml deleted file mode 100644 index 1283c09e7..000000000 --- a/config/proxy/molecule.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: gflownet.proxy.molecule.RFMoleculeEnergy - -path_to_model: './data/random_forest_reward_100.pkl' -url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF' \ No newline at end of file diff --git a/gflownet/envs/__init__.py b/gflownet/envs/__init__.py index edbe3a52c..69716b11b 100644 --- a/gflownet/envs/__init__.py +++ b/gflownet/envs/__init__.py @@ -1 +1 @@ -__all__ = ["base", "grid", "aptamers"] +__all__ = ["base", "grid", "aptamers", "alaninedipeptide"] diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 662eea0ff..d661b464d 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -1,11 +1,12 @@ import pickle -from typing import Union +from copy import deepcopy +from typing import Iterable, List import numpy as np -import numpy.typing as npt import torch import torchani from sklearn.ensemble import RandomForestRegressor +from torch import FloatTensor, LongTensor, Tensor from gflownet.proxy.base import Proxy from gflownet.utils.common import download_file_if_not_exists @@ -49,7 +50,13 @@ def __deepcopy__(self, memo): class TorchANIMoleculeEnergy(Proxy): - def __init__(self, model: str = "ANI2x", use_ensemble: bool = True, **kwargs): + def __init__( + self, + model: str = "ANI1x", + use_ensemble: bool = False, + batch_size: int = 8, + **kwargs, + ): """ Parameters ---------- @@ -58,9 +65,15 @@ def __init__(self, model: str = "ANI2x", use_ensemble: bool = True, **kwargs): use_ensemble : bool Whether to use whole ensemble of the models for prediction or only the first one. + + batch_size : int + Batch size for TorchANI. """ super().__init__(**kwargs) + self.batch_size = batch_size + self.min = -500 + if TORCHANI_MODELS.get(model) is None: raise ValueError( f'Tried to use model "{model}", ' @@ -71,35 +84,45 @@ def __init__(self, model: str = "ANI2x", use_ensemble: bool = True, **kwargs): periodic_table_index=True, model_index=None if use_ensemble else 0 ).to(self.device) + def setup(self, env=None): + self.conformer = env.conformer # deepcopy(env.conformer) + + def _sync_conformer_with_state(self, state: List): + for idx, ta in enumerate(self.conformer.freely_rotatable_tas): + self.conformer.set_torsion_angle(ta, state[idx]) + return self.conformer + @torch.no_grad() - def __call__( - self, - elements: Union[npt.NDArray[np.int64], torch.LongTensor], - coordinates: Union[npt.NDArray[np.float32], torch.FloatTensor], - ) -> npt.NDArray[np.float32]: + def __call__(self, states: Iterable) -> Tensor: """ Args ---- - elements : tensor - Either numpy or torch tensor with dimensionality (batch_size, n_atoms), - with values indicating atomic number of individual atoms. - - coordinates : tensor - Either numpy or torch tensor with dimensionality (batch_size, n_atoms, 3), - with values indicating 3D positions of individual atoms. + states + An iterable of states in AlanineDipeptide environment format (torsion angles). Returns ---- energies : tensor - Numpy array with dimensionality (batch_size,), containing energies + Torch with dimensionality (batch_size,), containing energies predicted by a TorchANI model (in Hartree). """ - if isinstance(elements, np.ndarray): - elements = torch.from_numpy(elements) - if isinstance(coordinates, np.ndarray): - coordinates = torch.from_numpy(coordinates) + elements = [] + coordinates = [] + + for st in states: + conf = self._sync_conformer_with_state(st) + + elements.append(conf.get_atomic_numbers()) + coordinates.append(conf.get_atom_positions()) + + elements = LongTensor(np.array(elements)).to(self.device) + coordinates = FloatTensor(np.array(coordinates)).to(self.device) - elements = elements.long().to(self.device) - coordinates = coordinates.float().to(self.device) + energies = [] + for elements_batch, coordinates_batch in zip( + torch.split(elements, self.batch_size), + torch.split(coordinates, self.batch_size), + ): + energies.append(self.model((elements_batch, coordinates_batch)).energies) - return self.model((elements, coordinates)).energies.cpu().numpy() + return torch.cat(energies).float() From e87e441655790f3e02a1db14056aa6c7ac9b0878 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 17 May 2023 12:08:35 -0400 Subject: [PATCH 017/100] added configs --- .../alaninedipeptide_torchani.yaml | 57 +++++++++++++++++++ config/proxy/molecule_rf.yaml | 4 ++ config/proxy/molecule_torchani.yaml | 1 + 3 files changed, 62 insertions(+) create mode 100644 config/experiments/alaninedipeptide_torchani.yaml create mode 100644 config/proxy/molecule_rf.yaml create mode 100644 config/proxy/molecule_torchani.yaml diff --git a/config/experiments/alaninedipeptide_torchani.yaml b/config/experiments/alaninedipeptide_torchani.yaml new file mode 100644 index 000000000..4232841d3 --- /dev/null +++ b/config/experiments/alaninedipeptide_torchani.yaml @@ -0,0 +1,57 @@ +# @package _global_ + +defaults: + - override /env: alaninedipeptide + - override /gflownet: trajectorybalance + - override /proxy: molecule_torchani + - override /logger: wandb + +# Environment +env: + length_traj: 10 + policy_encoding_dim_per_angle: 10 + n_comp: 5 + vonmises_min_concentration: 4 + reward_func: boltzmann + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: 100 + lr: 0.00001 + z_dim: 16 + lr_z_mult: 1000 + n_train_steps: 40000 + lr_decay_period: 1000000 + policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + backward: + type: mlp + n_hid: 512 + n_layers: 5 + shared_weights: False + checkpoint: backward + +# WandB +logger: + lightweight: True + project_name: "gflownet" + tags: + - gflownet + - continuous + - molecule + test: + period: 5000 + n: 10000 + checkpoints: + period: 5000 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/molecule/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/proxy/molecule_rf.yaml b/config/proxy/molecule_rf.yaml new file mode 100644 index 000000000..1283c09e7 --- /dev/null +++ b/config/proxy/molecule_rf.yaml @@ -0,0 +1,4 @@ +_target_: gflownet.proxy.molecule.RFMoleculeEnergy + +path_to_model: './data/random_forest_reward_100.pkl' +url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF' \ No newline at end of file diff --git a/config/proxy/molecule_torchani.yaml b/config/proxy/molecule_torchani.yaml new file mode 100644 index 000000000..290c2e798 --- /dev/null +++ b/config/proxy/molecule_torchani.yaml @@ -0,0 +1 @@ +_target_: gflownet.proxy.molecule.TorchANIMoleculeEnergy \ No newline at end of file From 1d095be57b64c68f3e982368a9e82971c6c0c7af Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 17 May 2023 12:09:36 -0400 Subject: [PATCH 018/100] updated proxy defaults --- gflownet/proxy/molecule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index d661b464d..5baa0b494 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -52,9 +52,9 @@ def __deepcopy__(self, memo): class TorchANIMoleculeEnergy(Proxy): def __init__( self, - model: str = "ANI1x", - use_ensemble: bool = False, - batch_size: int = 8, + model: str = "ANI2x", + use_ensemble: bool = True, + batch_size: int = 128, **kwargs, ): """ From 879dd38ed634ce35b7a0eb426ac6a7c9cc16a34f Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 17 May 2023 17:08:20 -0400 Subject: [PATCH 019/100] overwritten deepcopy --- gflownet/proxy/molecule.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 5baa0b494..cfb67087e 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -83,6 +83,7 @@ def __init__( self.model = TORCHANI_MODELS[model]( periodic_table_index=True, model_index=None if use_ensemble else 0 ).to(self.device) + self.conformer = None def setup(self, env=None): self.conformer = env.conformer # deepcopy(env.conformer) @@ -126,3 +127,12 @@ def __call__(self, states: Iterable) -> Tensor: energies.append(self.model((elements_batch, coordinates_batch)).energies) return torch.cat(energies).float() + + def __deepcopy__(self, memo): + cls = self.__class__ + new_obj = cls.__new__(cls) + new_obj.batch_size = self.batch_size + new_obj.min = self.min + new_obj.model = self.model + new_obj.conformer = self.conformer + return new_obj From c2897b751c329669b161cca346aa498122758101 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 17 May 2023 20:07:32 -0400 Subject: [PATCH 020/100] optional batching --- gflownet/proxy/molecule.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index cfb67087e..6029fb394 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -1,6 +1,6 @@ import pickle from copy import deepcopy -from typing import Iterable, List +from typing import Iterable, List, Optional import numpy as np import torch @@ -54,7 +54,7 @@ def __init__( self, model: str = "ANI2x", use_ensemble: bool = True, - batch_size: int = 128, + batch_size: Optional[int] = None, **kwargs, ): """ @@ -67,7 +67,7 @@ def __init__( Whether to use whole ensemble of the models for prediction or only the first one. batch_size : int - Batch size for TorchANI. + Batch size for TorchANI. If none, will process all states as a single batch. """ super().__init__(**kwargs) @@ -119,14 +119,18 @@ def __call__(self, states: Iterable) -> Tensor: elements = LongTensor(np.array(elements)).to(self.device) coordinates = FloatTensor(np.array(coordinates)).to(self.device) - energies = [] - for elements_batch, coordinates_batch in zip( - torch.split(elements, self.batch_size), - torch.split(coordinates, self.batch_size), - ): - energies.append(self.model((elements_batch, coordinates_batch)).energies) - - return torch.cat(energies).float() + if self.batch_size is not None: + energies = [] + for elements_batch, coordinates_batch in zip( + torch.split(elements, self.batch_size), + torch.split(coordinates, self.batch_size), + ): + energies.append(self.model((elements_batch, coordinates_batch)).energies) + energies = torch.cat(energies).float() + else: + energies = self.model((elements, coordinates)).energies.float() + + return energies def __deepcopy__(self, memo): cls = self.__class__ From e51b8f065d863c35e88a3488898ea3bbadab7300 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 17 May 2023 20:40:01 -0400 Subject: [PATCH 021/100] scaled energy --- gflownet/proxy/molecule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 6029fb394..1df2c650e 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -72,7 +72,7 @@ def __init__( super().__init__(**kwargs) self.batch_size = batch_size - self.min = -500 + self.min = -5 if TORCHANI_MODELS.get(model) is None: raise ValueError( @@ -130,7 +130,7 @@ def __call__(self, states: Iterable) -> Tensor: else: energies = self.model((elements, coordinates)).energies.float() - return energies + return energies / 100 def __deepcopy__(self, memo): cls = self.__class__ From 14658279f5db8c5fe58d76f0bfda041904920a6f Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 18 May 2023 09:59:07 -0400 Subject: [PATCH 022/100] updated config --- config/proxy/molecule_torchani.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/config/proxy/molecule_torchani.yaml b/config/proxy/molecule_torchani.yaml index 290c2e798..e96f6e4f1 100644 --- a/config/proxy/molecule_torchani.yaml +++ b/config/proxy/molecule_torchani.yaml @@ -1 +1,4 @@ -_target_: gflownet.proxy.molecule.TorchANIMoleculeEnergy \ No newline at end of file +_target_: gflownet.proxy.molecule.TorchANIMoleculeEnergy + +model: ANI2x +use_ensemble: True \ No newline at end of file From 4b8816ae75b264a4dd88f405a45f5ca9fe4537a2 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 18 May 2023 11:28:16 -0400 Subject: [PATCH 023/100] fixed docstring --- gflownet/proxy/molecule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 1df2c650e..01f42dd2e 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -104,7 +104,7 @@ def __call__(self, states: Iterable) -> Tensor: Returns ---- energies : tensor - Torch with dimensionality (batch_size,), containing energies + Torch tensor with dimensionality (batch_size,), containing energies predicted by a TorchANI model (in Hartree). """ elements = [] From 9b4fa153a39f94a4862cf440e1f13559a30dd314 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 31 May 2023 15:03:37 -0400 Subject: [PATCH 024/100] energy divider as an argument --- gflownet/proxy/molecule.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 01f42dd2e..4c8e9ffe6 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -55,6 +55,7 @@ def __init__( model: str = "ANI2x", use_ensemble: bool = True, batch_size: Optional[int] = None, + divider: float = 100.0, **kwargs, ): """ @@ -68,10 +69,15 @@ def __init__( batch_size : int Batch size for TorchANI. If none, will process all states as a single batch. + + divider : float + The value by which the output of TorchANI will be divided. Necessary for Boltzmann + reward function with high betas, for which the values can explode without division. """ super().__init__(**kwargs) self.batch_size = batch_size + self.divider = divider self.min = -5 if TORCHANI_MODELS.get(model) is None: @@ -125,12 +131,14 @@ def __call__(self, states: Iterable) -> Tensor: torch.split(elements, self.batch_size), torch.split(coordinates, self.batch_size), ): - energies.append(self.model((elements_batch, coordinates_batch)).energies) + energies.append( + self.model((elements_batch, coordinates_batch)).energies + ) energies = torch.cat(energies).float() else: energies = self.model((elements, coordinates)).energies.float() - return energies / 100 + return energies / self.divider def __deepcopy__(self, memo): cls = self.__class__ From feee48ba07dd923fb5a4b0dce17843ccfcd7bc0a Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 31 May 2023 15:03:59 -0400 Subject: [PATCH 025/100] removed unused import --- gflownet/proxy/molecule.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 4c8e9ffe6..236ea76be 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -1,5 +1,4 @@ import pickle -from copy import deepcopy from typing import Iterable, List, Optional import numpy as np @@ -92,7 +91,7 @@ def __init__( self.conformer = None def setup(self, env=None): - self.conformer = env.conformer # deepcopy(env.conformer) + self.conformer = env.conformer def _sync_conformer_with_state(self, state: List): for idx, ta in enumerate(self.conformer.freely_rotatable_tas): From 004f905ca5b32ba54f5ec1890395fb3977e4f1ef Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 1 Jun 2023 15:05:08 -0400 Subject: [PATCH 026/100] added aromatic bond type --- gflownet/utils/molecule/constants.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/molecule/constants.py b/gflownet/utils/molecule/constants.py index ce963c3a1..39b5dabf8 100644 --- a/gflownet/utils/molecule/constants.py +++ b/gflownet/utils/molecule/constants.py @@ -14,7 +14,9 @@ ad_atom_types = ("H", "C", "N", "O") atom_degrees = tuple(range(1, 7)) atom_hybridizations = tuple(list(Chem.rdchem.HybridizationType.names.values())) -bond_types = tuple([Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE]) +bond_types = tuple( + [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.AROMATIC] +) # SMILES strings ad_smiles = "CC(C(=O)NC)NC(=O)C" From 6a908deb21aef46b1bbf823a99d191d8cbedd6fd Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Mon, 5 Jun 2023 10:57:43 -0400 Subject: [PATCH 027/100] XTB proxy --- gflownet/proxy/molecule.py | 61 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 236ea76be..0b9bc147b 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -2,10 +2,15 @@ from typing import Iterable, List, Optional import numpy as np +import numpy.typing as npt import torch import torchani from sklearn.ensemble import RandomForestRegressor from torch import FloatTensor, LongTensor, Tensor +from wurlitzer import pipes +from xtb.interface import Calculator, XTBException +from xtb.libxtb import VERBOSITY_MUTED +from xtb.utils import get_method from gflownet.proxy.base import Proxy from gflownet.utils.common import download_file_if_not_exists @@ -48,6 +53,62 @@ def __deepcopy__(self, memo): return new_obj +class XTBMoleculeEnergy(Proxy): + def __init__(self, method: str = "gfn-ff", **kwargs): + super().__init__(**kwargs) + + self.method = get_method(method) + if self.method is None: + raise ValueError(f"Unrecognized XTB method: {method}.") + self.min = -1000 + self.max = 1000 + self.conformer = None + + def setup(self, env=None): + self.conformer = env.conformer + + def _sync_conformer_with_state(self, state: List): + for idx, ta in enumerate(self.conformer.freely_rotatable_tas): + self.conformer.set_torsion_angle(ta, state[idx]) + return self.conformer + + def __call__(self, states: Iterable) -> Tensor: + elements = [] + coordinates = [] + + for st in states: + conf = self._sync_conformer_with_state(st) + + elements.append(conf.get_atomic_numbers()) + coordinates.append(conf.get_atom_positions()) + + # todo: probably make it parallel with mpi + energies = [self.get_energy(c, e) for (c, e) in zip(coordinates, elements)] + + return torch.tensor(energies, dtype=self.float, device=self.device) + + def get_energy( + self, + atom_positions: npt.NDArray[np.float32], + atomic_numbers: npt.NDArray[np.int64], + ) -> float: + """ + Compute energy of a molecule defined by atom_positions and atomic_numbers + """ + with pipes(): + calc = Calculator(self.method, atomic_numbers, atom_positions) + calc.set_verbosity(VERBOSITY_MUTED) + try: + energy = calc.singlepoint().get_energy() + + if np.isnan(energy): + return self.max + + return energy + except XTBException: + return self.max + + class TorchANIMoleculeEnergy(Proxy): def __init__( self, From 032759be08eead81943a4c0c10c18b2cd6f9c5f1 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Mon, 5 Jun 2023 15:31:26 -0400 Subject: [PATCH 028/100] black --- gflownet/utils/molecule/constants.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/molecule/constants.py b/gflownet/utils/molecule/constants.py index 39b5dabf8..fb1b45756 100644 --- a/gflownet/utils/molecule/constants.py +++ b/gflownet/utils/molecule/constants.py @@ -15,7 +15,11 @@ atom_degrees = tuple(range(1, 7)) atom_hybridizations = tuple(list(Chem.rdchem.HybridizationType.names.values())) bond_types = tuple( - [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.AROMATIC] + [ + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.AROMATIC, + ] ) # SMILES strings From 541e316fb9d2c24c72c9e03257e6691100ce2f74 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 6 Jun 2023 14:10:39 -0400 Subject: [PATCH 029/100] XTB using command line interface instead of Python API --- gflownet/proxy/molecule.py | 58 ++++++++---------- gflownet/utils/molecule/xtb.py | 108 +++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 34 deletions(-) create mode 100644 gflownet/utils/molecule/xtb.py diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 0b9bc147b..edc9a78ee 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -1,19 +1,18 @@ import pickle +from pathlib import Path +from tempfile import TemporaryDirectory from typing import Iterable, List, Optional import numpy as np -import numpy.typing as npt import torch import torchani +from rdkit.Chem.rdmolfiles import MolToXYZFile from sklearn.ensemble import RandomForestRegressor from torch import FloatTensor, LongTensor, Tensor -from wurlitzer import pipes -from xtb.interface import Calculator, XTBException -from xtb.libxtb import VERBOSITY_MUTED -from xtb.utils import get_method from gflownet.proxy.base import Proxy from gflownet.utils.common import download_file_if_not_exists +from gflownet.utils.molecule.xtb import run_gfn_xtb TORCHANI_MODELS = { "ANI1x": torchani.models.ANI1x, @@ -54,14 +53,12 @@ def __deepcopy__(self, memo): class XTBMoleculeEnergy(Proxy): - def __init__(self, method: str = "gfn-ff", **kwargs): + def __init__(self, method: str = "gfnff", **kwargs): super().__init__(**kwargs) - self.method = get_method(method) - if self.method is None: - raise ValueError(f"Unrecognized XTB method: {method}.") - self.min = -1000 - self.max = 1000 + self.method = method + self.min = -5 + self.max = 0 self.conformer = None def setup(self, env=None): @@ -73,40 +70,33 @@ def _sync_conformer_with_state(self, state: List): return self.conformer def __call__(self, states: Iterable) -> Tensor: - elements = [] - coordinates = [] + directories = [] for st in states: conf = self._sync_conformer_with_state(st) - - elements.append(conf.get_atomic_numbers()) - coordinates.append(conf.get_atom_positions()) + directory = TemporaryDirectory() + directories.append(directory) + MolToXYZFile(conf.rdk_mol, str(Path(directory.name) / "input.xyz")) # todo: probably make it parallel with mpi - energies = [self.get_energy(c, e) for (c, e) in zip(coordinates, elements)] + energies = [self.get_energy(d, "input.xyz") for d in directories] + + for directory in directories: + directory.cleanup() return torch.tensor(energies, dtype=self.float, device=self.device) def get_energy( self, - atom_positions: npt.NDArray[np.float32], - atomic_numbers: npt.NDArray[np.int64], + directory: TemporaryDirectory, + file_name: str ) -> float: - """ - Compute energy of a molecule defined by atom_positions and atomic_numbers - """ - with pipes(): - calc = Calculator(self.method, atomic_numbers, atom_positions) - calc.set_verbosity(VERBOSITY_MUTED) - try: - energy = calc.singlepoint().get_energy() - - if np.isnan(energy): - return self.max - - return energy - except XTBException: - return self.max + energy = run_gfn_xtb(directory.name, file_name, gfn_version=self.method) + + if np.isnan(energy): + return self.max + + return energy class TorchANIMoleculeEnergy(Proxy): diff --git a/gflownet/utils/molecule/xtb.py b/gflownet/utils/molecule/xtb.py new file mode 100644 index 000000000..941277f52 --- /dev/null +++ b/gflownet/utils/molecule/xtb.py @@ -0,0 +1,108 @@ +import contextlib +import os +import re +import subprocess +import warnings + +import numpy as np + + +def _get_energy(file): + normal_termination = False + with open(file) as f: + for l in f: + if "TOTAL ENERGY" in l: + energy = float(re.search(r"[+-]?(?:\d*\.)?\d+", l).group()) + if "normal termination of xtb" in l: + normal_termination = True + if normal_termination: + return energy + else: + return np.nan + + +def run_gfn_xtb( + filepath, + filename, + gfn_version="gfnff", + opt=False, + gfn_xtb_config: str = None, + remove_scratch=False, +): + """ + Runs GFN_XTB/FF given a directory and either a coord file or all coord files will be run + + :param filepath: Directory containing the coord file + :param filename: if given, the specific coord file to run + :param gfn_version: GFN_xtb version (default is 2) + :param opt: optimization or single point (default is opt) + :param gfn_xtb_config: additional xtb config (default is None) + :param remove_scratch: remove xtb files + :return: + """ + xyz_file = os.path.join(filepath, filename) + + # optimization vs single point + if opt: + opt = "--opt" + else: + opt = "" + + # cd to filepath + starting_dir = os.getcwd() + os.chdir(filepath) + + file_name = str(xyz_file.split(".")[0]) + cmd = "xtb --{} {} {} {}".format( + str(gfn_version), xyz_file, opt, str(gfn_xtb_config or "") + ) + + # run XTB + with open(file_name + ".out", "w") as fd: + subprocess.run(cmd, shell=True, stdout=fd, stderr=subprocess.STDOUT) + + # check XTB results + if os.path.isfile(os.path.join(filepath, "NOT_CONVERGED")): + # optimization not converged + warnings.warn( + "xtb --{} for {} is not converged, using last optimized step instead; proceed with caution".format( + str(gfn_version), file_name + ) + ) + + # remove files + if remove_scratch: + os.remove(os.path.join(filepath, "NOT_CONVERGED")) + os.remove(os.path.join(filepath, "xtblast.xyz")) + os.remove(os.path.join(filepath, file_name + ".out")) + energy = np.nan + + elif opt and not os.path.isfile(os.path.join(filepath, "xtbopt.xyz")): + # other abnormal optimization convergence + warnings.warn( + "xtb --{} for {} abnormal termination, likely scf issues, using initial geometry instead; proceed with caution".format( + str(gfn_version), file_name + ) + ) + if remove_scratch: + os.remove(os.path.join(filepath, file_name + ".out")) + energy = np.nan + + else: + # normal convergence + # get energy + energy = _get_energy(file_name + ".out") + if remove_scratch: + with contextlib.suppress(FileNotFoundError): + os.remove(os.path.join(filepath, file_name + ".out")) + os.remove(os.path.join(filepath, "gfnff_charges")) + os.remove(os.path.join(filepath, "gfnff_adjacency")) + os.remove(os.path.join(filepath, 'gfnff_topo')) + os.remove(os.path.join(filepath, "xtbopt.log")) + os.remove(os.path.join(filepath, "xtbopt.xyz")) + os.remove(os.path.join(filepath, "xtbtopo.mol")) + os.remove(os.path.join(filepath, "wbo")) + os.remove(os.path.join(filepath, "charges")) + os.remove(os.path.join(filepath, "xtbrestart")) + os.chdir(starting_dir) + return energy From fe6cecc9ebe8ed21c037e25630c2c4f72a2722d6 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 8 Jun 2023 10:13:07 -0400 Subject: [PATCH 030/100] cleaning scratch by default --- gflownet/utils/molecule/xtb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/utils/molecule/xtb.py b/gflownet/utils/molecule/xtb.py index 941277f52..969fb6daa 100644 --- a/gflownet/utils/molecule/xtb.py +++ b/gflownet/utils/molecule/xtb.py @@ -27,7 +27,7 @@ def run_gfn_xtb( gfn_version="gfnff", opt=False, gfn_xtb_config: str = None, - remove_scratch=False, + remove_scratch=True, ): """ Runs GFN_XTB/FF given a directory and either a coord file or all coord files will be run From eaccf0156cd27e42729d2735cbe14e136a9b2144 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 9 Jun 2023 10:13:49 -0400 Subject: [PATCH 031/100] tblite --- config/experiments/alaninedipeptide_xtb.yaml | 57 +++++++++++++++++++ config/proxy/molecule_xtb.yaml | 1 + gflownet/proxy/molecule.py | 58 ++++++++++---------- gflownet/utils/molecule/xtb.py | 2 +- main.py | 8 +++ setup_conformer_conda.sh | 31 +++++++++++ 6 files changed, 126 insertions(+), 31 deletions(-) create mode 100644 config/experiments/alaninedipeptide_xtb.yaml create mode 100644 config/proxy/molecule_xtb.yaml create mode 100644 setup_conformer_conda.sh diff --git a/config/experiments/alaninedipeptide_xtb.yaml b/config/experiments/alaninedipeptide_xtb.yaml new file mode 100644 index 000000000..981cd7dc0 --- /dev/null +++ b/config/experiments/alaninedipeptide_xtb.yaml @@ -0,0 +1,57 @@ +# @package _global_ + +defaults: + - override /env: alaninedipeptide + - override /gflownet: trajectorybalance + - override /proxy: molecule_xtb + - override /logger: wandb + +# Environment +env: + length_traj: 10 + policy_encoding_dim_per_angle: 10 + n_comp: 5 + vonmises_min_concentration: 4 + reward_func: boltzmann + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: 100 + lr: 0.00001 + z_dim: 16 + lr_z_mult: 1000 + n_train_steps: 40000 + lr_decay_period: 1000000 + policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + backward: + type: mlp + n_hid: 512 + n_layers: 5 + shared_weights: False + checkpoint: backward + +# WandB +logger: + lightweight: True + project_name: "gflownet" + tags: + - gflownet + - continuous + - molecule + test: + period: 5000 + n: 10000 + checkpoints: + period: 5000 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/molecule/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/proxy/molecule_xtb.yaml b/config/proxy/molecule_xtb.yaml new file mode 100644 index 000000000..e7f26328c --- /dev/null +++ b/config/proxy/molecule_xtb.yaml @@ -0,0 +1 @@ +_target_: gflownet.proxy.molecule.XTBMoleculeEnergy diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index edc9a78ee..085cf7bef 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -1,18 +1,15 @@ +from tblite.interface import Calculator, Structure import pickle -from pathlib import Path -from tempfile import TemporaryDirectory from typing import Iterable, List, Optional - +import ray import numpy as np import torch import torchani -from rdkit.Chem.rdmolfiles import MolToXYZFile from sklearn.ensemble import RandomForestRegressor from torch import FloatTensor, LongTensor, Tensor from gflownet.proxy.base import Proxy from gflownet.utils.common import download_file_if_not_exists -from gflownet.utils.molecule.xtb import run_gfn_xtb TORCHANI_MODELS = { "ANI1x": torchani.models.ANI1x, @@ -52,11 +49,24 @@ def __deepcopy__(self, memo): return new_obj +@ray.remote +def _get_energy(numbers, positions): + calc = Calculator("GFN2-xTB", numbers, positions) + res = calc.singlepoint() + return res.get("energy").item() + + +def _chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + class XTBMoleculeEnergy(Proxy): - def __init__(self, method: str = "gfnff", **kwargs): + def __init__(self, batch_size=100, **kwargs): super().__init__(**kwargs) - self.method = method + self.batch_size = batch_size self.min = -5 self.max = 0 self.conformer = None @@ -69,35 +79,23 @@ def _sync_conformer_with_state(self, state: List): self.conformer.set_torsion_angle(ta, state[idx]) return self.conformer - def __call__(self, states: Iterable) -> Tensor: - directories = [] + def __call__(self, states: List) -> Tensor: + energies = [] - for st in states: - conf = self._sync_conformer_with_state(st) - directory = TemporaryDirectory() - directories.append(directory) - MolToXYZFile(conf.rdk_mol, str(Path(directory.name) / "input.xyz")) + for batch in _chunks(states, self.batch_size): + structures = [] - # todo: probably make it parallel with mpi - energies = [self.get_energy(d, "input.xyz") for d in directories] + for state in batch: + conf = self._sync_conformer_with_state(state) + structures.append( + (conf.get_atomic_numbers(), conf.get_atom_positions()) + ) - for directory in directories: - directory.cleanup() + tasks = [_get_energy.remote(s[0], s[1]) for s in structures] + energies.extend(ray.get(tasks)) return torch.tensor(energies, dtype=self.float, device=self.device) - def get_energy( - self, - directory: TemporaryDirectory, - file_name: str - ) -> float: - energy = run_gfn_xtb(directory.name, file_name, gfn_version=self.method) - - if np.isnan(energy): - return self.max - - return energy - class TorchANIMoleculeEnergy(Proxy): def __init__( diff --git a/gflownet/utils/molecule/xtb.py b/gflownet/utils/molecule/xtb.py index 969fb6daa..7fd05901e 100644 --- a/gflownet/utils/molecule/xtb.py +++ b/gflownet/utils/molecule/xtb.py @@ -97,7 +97,7 @@ def run_gfn_xtb( os.remove(os.path.join(filepath, file_name + ".out")) os.remove(os.path.join(filepath, "gfnff_charges")) os.remove(os.path.join(filepath, "gfnff_adjacency")) - os.remove(os.path.join(filepath, 'gfnff_topo')) + os.remove(os.path.join(filepath, "gfnff_topo")) os.remove(os.path.join(filepath, "xtbopt.log")) os.remove(os.path.join(filepath, "xtbopt.xyz")) os.remove(os.path.join(filepath, "xtbtopo.mol")) diff --git a/main.py b/main.py index d73a6d95d..ca14df7c0 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,14 @@ """ Runnable script with hydra capabilities """ + +# This is a hotfix for tblite (used for the conformer generation) not +# importing correctly unless it is being imported first. +try: + from tblite import interface +except: + pass + import os import pickle import random diff --git a/setup_conformer_conda.sh b/setup_conformer_conda.sh new file mode 100644 index 000000000..8c371213c --- /dev/null +++ b/setup_conformer_conda.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Installs the conda environment with the name passed as first argument. +# We need conda to install tblite for conformer generation. +# +# Arguments +# $1: Environment name +# +module --force purge +module load cuda/11.7 + +conda create -n $1 python=3.8 +conda activate $1 + +conda install mamba -n base -c conda-forge + +mamba install tblite -c conda-forge +mamba install tblite-python -c conda-forge + +# Update pip +python -m pip install --upgrade pip +# Install PyTorch family +python -m pip install torch torchvision torchaudio +python -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html +# Install DGL (see https://www.dgl.ai/pages/start.html) +python -m pip install dgl -f https://data.dgl.ai/wheels/cu117/repo.html +# Requirements to run +python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d rdkit +# Conditional requirements +python -m pip install wandb matplotlib plotly pymatgen gdown +# Dev packages +# python -m pip install black flake8 isort pylint ipdb jupyter pytest pytest-repeat From 91ef7f84b3b16a385b36d776748757fd5a8b200f Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 9 Jun 2023 16:23:38 -0400 Subject: [PATCH 032/100] conversion --- gflownet/proxy/molecule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 085cf7bef..b9fc62c46 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -51,7 +51,7 @@ def __deepcopy__(self, memo): @ray.remote def _get_energy(numbers, positions): - calc = Calculator("GFN2-xTB", numbers, positions) + calc = Calculator("GFN2-xTB", numbers, positions * 1.8897259886) res = calc.singlepoint() return res.get("energy").item() From 5e29fbf9efb18808e3858c100b7f4b5104c4aad2 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 09:23:06 -0400 Subject: [PATCH 033/100] rearranged conformer configs and environments --- config/env/{ => conformers}/alaninedipeptide.yaml | 2 +- config/experiments/alaninedipeptide_torchani.yaml | 4 ++-- config/experiments/alaninedipeptide_xtb.yaml | 4 ++-- config/experiments/icml23/alaninedipeptide.yaml | 4 ++-- config/proxy/{ => conformers}/molecule_rf.yaml | 0 config/proxy/{ => conformers}/molecule_torchani.yaml | 0 config/proxy/{ => conformers}/molecule_xtb.yaml | 0 gflownet/envs/{ => conformers}/alaninedipeptide.py | 0 8 files changed, 7 insertions(+), 7 deletions(-) rename config/env/{ => conformers}/alaninedipeptide.yaml (91%) rename config/proxy/{ => conformers}/molecule_rf.yaml (100%) rename config/proxy/{ => conformers}/molecule_torchani.yaml (100%) rename config/proxy/{ => conformers}/molecule_xtb.yaml (100%) rename gflownet/envs/{ => conformers}/alaninedipeptide.py (100%) diff --git a/config/env/alaninedipeptide.yaml b/config/env/conformers/alaninedipeptide.yaml similarity index 91% rename from config/env/alaninedipeptide.yaml rename to config/env/conformers/alaninedipeptide.yaml index 71be8f7f6..12947087a 100644 --- a/config/env/alaninedipeptide.yaml +++ b/config/env/conformers/alaninedipeptide.yaml @@ -1,7 +1,7 @@ defaults: - base -_target_: gflownet.envs.alaninedipeptide.AlanineDipeptide +_target_: gflownet.envs.conformers.alaninedipeptide.AlanineDipeptide path_to_dataset: './data/alanine_dipeptide_conformers_1.npy' url_to_dataset: 'https://drive.google.com/uc?id=1r1KRGcpBhR3xaS8yt2i64dfMnJGgNj4C' diff --git a/config/experiments/alaninedipeptide_torchani.yaml b/config/experiments/alaninedipeptide_torchani.yaml index 4232841d3..56fe307fd 100644 --- a/config/experiments/alaninedipeptide_torchani.yaml +++ b/config/experiments/alaninedipeptide_torchani.yaml @@ -1,9 +1,9 @@ # @package _global_ defaults: - - override /env: alaninedipeptide + - override /env: conformers/alaninedipeptide - override /gflownet: trajectorybalance - - override /proxy: molecule_torchani + - override /proxy: conformers/molecule_torchani - override /logger: wandb # Environment diff --git a/config/experiments/alaninedipeptide_xtb.yaml b/config/experiments/alaninedipeptide_xtb.yaml index 981cd7dc0..8c95f9d1c 100644 --- a/config/experiments/alaninedipeptide_xtb.yaml +++ b/config/experiments/alaninedipeptide_xtb.yaml @@ -1,9 +1,9 @@ # @package _global_ defaults: - - override /env: alaninedipeptide + - override /env: conformers/alaninedipeptide - override /gflownet: trajectorybalance - - override /proxy: molecule_xtb + - override /proxy: conformers/molecule_xtb - override /logger: wandb # Environment diff --git a/config/experiments/icml23/alaninedipeptide.yaml b/config/experiments/icml23/alaninedipeptide.yaml index 597fb33e6..7d59ffea1 100644 --- a/config/experiments/icml23/alaninedipeptide.yaml +++ b/config/experiments/icml23/alaninedipeptide.yaml @@ -1,9 +1,9 @@ # @package _global_ defaults: - - override /env: alaninedipeptide + - override /env: conformers/alaninedipeptide - override /gflownet: trajectorybalance - - override /proxy: molecule_rf + - override /proxy: conformers/molecule_rf - override /logger: wandb - override /user: sasha diff --git a/config/proxy/molecule_rf.yaml b/config/proxy/conformers/molecule_rf.yaml similarity index 100% rename from config/proxy/molecule_rf.yaml rename to config/proxy/conformers/molecule_rf.yaml diff --git a/config/proxy/molecule_torchani.yaml b/config/proxy/conformers/molecule_torchani.yaml similarity index 100% rename from config/proxy/molecule_torchani.yaml rename to config/proxy/conformers/molecule_torchani.yaml diff --git a/config/proxy/molecule_xtb.yaml b/config/proxy/conformers/molecule_xtb.yaml similarity index 100% rename from config/proxy/molecule_xtb.yaml rename to config/proxy/conformers/molecule_xtb.yaml diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/conformers/alaninedipeptide.py similarity index 100% rename from gflownet/envs/alaninedipeptide.py rename to gflownet/envs/conformers/alaninedipeptide.py From 3deaeda7d7d0b813eb78d452d14e176d0d14457e Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 09:29:07 -0400 Subject: [PATCH 034/100] added wurlitzer (to supress XTB output) --- gflownet/proxy/molecule.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index af3831b70..1669a48a2 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -1,12 +1,14 @@ -from tblite.interface import Calculator, Structure import pickle from typing import Iterable, List, Optional + import ray import numpy as np import torch import torchani from sklearn.ensemble import RandomForestRegressor +from tblite.interface import Calculator, Structure from torch import FloatTensor, LongTensor, Tensor +from wurlitzer import pipes from gflownet.proxy.base import Proxy from gflownet.utils.common import download_file_if_not_exists @@ -51,9 +53,12 @@ def __deepcopy__(self, memo): @ray.remote def _get_energy(numbers, positions): - calc = Calculator("GFN2-xTB", numbers, positions * 1.8897259886) - res = calc.singlepoint() - return res.get("energy").item() + with pipes(): + calc = Calculator("GFN2-xTB", numbers, positions * 1.8897259886) + res = calc.singlepoint() + energy = res.get("energy").item() + + return energy def _chunks(lst, n): From 7475f65a43da8e8e047af60236bc498b2edc4ff6 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 10:03:04 -0400 Subject: [PATCH 035/100] conformer environment --- config/env/conformers/conformer.yaml | 40 ++++++++ .../alaninedipeptide_torchani.yaml | 57 ------------ ...aninedipeptide_xtb.yaml => conformer.yaml} | 2 +- gflownet/envs/conformers/__init__.py | 0 gflownet/envs/conformers/conformer.py | 91 +++++++++++++++++++ 5 files changed, 132 insertions(+), 58 deletions(-) create mode 100644 config/env/conformers/conformer.yaml delete mode 100644 config/experiments/alaninedipeptide_torchani.yaml rename config/experiments/{alaninedipeptide_xtb.yaml => conformer.yaml} (95%) create mode 100644 gflownet/envs/conformers/__init__.py create mode 100644 gflownet/envs/conformers/conformer.py diff --git a/config/env/conformers/conformer.yaml b/config/env/conformers/conformer.yaml new file mode 100644 index 000000000..ee8afce89 --- /dev/null +++ b/config/env/conformers/conformer.yaml @@ -0,0 +1,40 @@ +defaults: + - base + +_target_: gflownet.envs.conformers.conformer.Conformer + +# alanine dipeptide +smiles: 'CC(C(=O)NC)NC(=O)C' +torsion_angles: [[0, 1, 2, 3], [0, 1, 6, 7]] +path_to_dataset: './data/alanine_dipeptide_conformers_1.npy' +url_to_dataset: 'https://drive.google.com/uc?id=1r1KRGcpBhR3xaS8yt2i64dfMnJGgNj4C' + +# ibuprofen +# smiles: 'CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O' +# torsion_angles: [[2, 1, 3, 4], [1, 3, 4, 5]] +# path_to_dataset: './data/ibuprofen_conformers_1.npy' +# url_to_dataset: 'https://drive.google.com/uc?id=1wRvaiQ0H2gP3gNqRfpwXRJJ4pF70ulyf' + +id: conformer +policy_encoding_dim_per_angle: null +# Fixed length of trajectories +length_traj: 10 +vonmises_min_concentration: 1e-3 +# Parameters of the fixed policy output distribution +n_comp: 3 +fixed_distribution: + vonmises_mean: 0.0 + vonmises_concentration: 0.5 +# Parameters of the random policy output distribution +random_distribution: + vonmises_mean: 0.0 + vonmises_concentration: 0.001 +# Buffer +buffer: + data_path: null + train: null + test: + type: grid + n: 1000 + output_csv: conformer_test.csv + output_pkl: conformer_test.pkl diff --git a/config/experiments/alaninedipeptide_torchani.yaml b/config/experiments/alaninedipeptide_torchani.yaml deleted file mode 100644 index 56fe307fd..000000000 --- a/config/experiments/alaninedipeptide_torchani.yaml +++ /dev/null @@ -1,57 +0,0 @@ -# @package _global_ - -defaults: - - override /env: conformers/alaninedipeptide - - override /gflownet: trajectorybalance - - override /proxy: conformers/molecule_torchani - - override /logger: wandb - -# Environment -env: - length_traj: 10 - policy_encoding_dim_per_angle: 10 - n_comp: 5 - vonmises_min_concentration: 4 - reward_func: boltzmann - -# GFlowNet hyperparameters -gflownet: - random_action_prob: 0.1 - optimizer: - batch_size: 100 - lr: 0.00001 - z_dim: 16 - lr_z_mult: 1000 - n_train_steps: 40000 - lr_decay_period: 1000000 - policy: - forward: - type: mlp - n_hid: 512 - n_layers: 5 - checkpoint: forward - backward: - type: mlp - n_hid: 512 - n_layers: 5 - shared_weights: False - checkpoint: backward - -# WandB -logger: - lightweight: True - project_name: "gflownet" - tags: - - gflownet - - continuous - - molecule - test: - period: 5000 - n: 10000 - checkpoints: - period: 5000 - -# Hydra -hydra: - run: - dir: ${user.logdir.root}/molecule/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/experiments/alaninedipeptide_xtb.yaml b/config/experiments/conformer.yaml similarity index 95% rename from config/experiments/alaninedipeptide_xtb.yaml rename to config/experiments/conformer.yaml index 8c95f9d1c..11f2f1737 100644 --- a/config/experiments/alaninedipeptide_xtb.yaml +++ b/config/experiments/conformer.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /env: conformers/alaninedipeptide + - override /env: conformers/conformer - override /gflownet: trajectorybalance - override /proxy: conformers/molecule_xtb - override /logger: wandb diff --git a/gflownet/envs/conformers/__init__.py b/gflownet/envs/conformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py new file mode 100644 index 000000000..9ada4e723 --- /dev/null +++ b/gflownet/envs/conformers/conformer.py @@ -0,0 +1,91 @@ +from typing import List, Tuple + +import numpy as np +import numpy.typing as npt +import torch +from torchtyping import TensorType + +from gflownet.envs.ctorus import ContinuousTorus +from gflownet.utils.molecule.datasets import AtomPositionsDataset +from gflownet.utils.molecule.rdkit_conformer import RDKitConformer + + +class Conformer(ContinuousTorus): + """ + Extension of continuous torus to conformer generation. Based on AlanineDipeptide, + but accepts any molecule (defined by SMILES, freely rotatable torsion angles, and + path to dataset containing sample conformers. + """ + def __init__( + self, + smiles: str, + torsion_angles: List[List[int]], + path_to_dataset: str, + url_to_dataset: str, + **kwargs, + ): + self.atom_positions_dataset = AtomPositionsDataset( + path_to_dataset, url_to_dataset + ) + atom_positions = self.atom_positions_dataset.sample() + self.conformer = RDKitConformer( + atom_positions, smiles, torsion_angles + ) + n_dim = len(self.conformer.freely_rotatable_tas) + super().__init__(n_dim=n_dim, **kwargs) + self.sync_conformer_with_state() + + def sync_conformer_with_state(self, state: List = None): + if state is None: + state = self.state + for idx, ta in enumerate(self.conformer.freely_rotatable_tas): + self.conformer.set_torsion_angle(ta, state[idx]) + return self.conformer + + def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: + """ + Prepares a batch of states in torch "GFlowNet format" for the oracle. + """ + device = states.device + if device == torch.device("cpu"): + np_states = states.numpy() + else: + np_states = states.cpu().numpy() + return np_states[:, :-1] + + def statebatch2proxy(self, states: List[List]) -> npt.NDArray: + """ + Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where + each state is a row of length n_dim with an angle in radians. The n_actions + item is removed. + """ + return np.array(states)[:, :-1] + + def statetorch2oracle( + self, states: TensorType["batch", "state_dim"] + ) -> List[Tuple[npt.NDArray, npt.NDArray]]: + """ + Prepares a batch of states in torch "GFlowNet format" for the oracle. + """ + device = states.device + if device == torch.device("cpu"): + np_states = states.numpy() + else: + np_states = states.cpu().numpy() + result = self.statebatch2oracle(np_states) + return result + + def statebatch2oracle( + self, states: List[List] + ) -> List[Tuple[npt.NDArray, npt.NDArray]]: + """ + Prepares a batch of states in "GFlowNet format" for the oracle: a list of + tuples, where first element in the tuple is numpy array of atom positions of + shape [num_atoms, 3] and the second element is numpy array of atomic numbers of + shape [num_atoms, ] + """ + states_oracle = [] + for st in states: + conf = self.sync_conformer_with_state(st) + states_oracle.append((conf.get_atom_positions(), conf.get_atomic_numbers())) + return states_oracle From f15cb4c20a58a1d72b8b62f5b6423842bec834bd Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 11:37:30 -0400 Subject: [PATCH 036/100] deepcopy workaround for XTB proxy --- gflownet/proxy/molecule.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/molecule.py index 1669a48a2..81a0c1909 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/molecule.py @@ -101,6 +101,12 @@ def __call__(self, states: List) -> Tensor: return torch.tensor(energies, dtype=self.float, device=self.device) + def __deepcopy__(self, memo): + cls = self.__class__ + new_obj = cls.__new__(cls) + new_obj.__dict__.update(self.__dict__) + return new_obj + class TorchANIMoleculeEnergy(Proxy): def __init__( From caf6537596d2e028525528a7060d3dc103e9f448 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 11:44:09 -0400 Subject: [PATCH 037/100] simplified cpu casting --- gflownet/envs/conformers/conformer.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 9ada4e723..637be5e0d 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -2,7 +2,6 @@ import numpy as np import numpy.typing as npt -import torch from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus @@ -46,12 +45,7 @@ def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDAr """ Prepares a batch of states in torch "GFlowNet format" for the oracle. """ - device = states.device - if device == torch.device("cpu"): - np_states = states.numpy() - else: - np_states = states.cpu().numpy() - return np_states[:, :-1] + return states.cpu().numpy()[:, :-1] def statebatch2proxy(self, states: List[List]) -> npt.NDArray: """ @@ -67,13 +61,7 @@ def statetorch2oracle( """ Prepares a batch of states in torch "GFlowNet format" for the oracle. """ - device = states.device - if device == torch.device("cpu"): - np_states = states.numpy() - else: - np_states = states.cpu().numpy() - result = self.statebatch2oracle(np_states) - return result + return self.statebatch2oracle(states.cpu().numpy()) def statebatch2oracle( self, states: List[List] From 8353087b90bb429c90bc9de1238e71589632328a Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 12:36:11 -0400 Subject: [PATCH 038/100] conformer molecules split into separate files --- config/experiments/conformer.yaml | 2 +- .../experiments/icml23/alaninedipeptide.yaml | 2 +- .../proxy/conformers/molecule_torchani.yaml | 4 - config/proxy/conformers/molecule_xtb.yaml | 1 - .../conformers/{molecule_rf.yaml => rf.yaml} | 4 +- config/proxy/conformers/torchani.yaml | 4 + config/proxy/conformers/xtb.yaml | 1 + gflownet/proxy/conformers/__init__.py | 0 gflownet/proxy/conformers/rf.py | 39 ++++++++ .../{molecule.py => conformers/torchani.py} | 94 ------------------- gflownet/proxy/conformers/xtb.py | 66 +++++++++++++ 11 files changed, 114 insertions(+), 103 deletions(-) delete mode 100644 config/proxy/conformers/molecule_torchani.yaml delete mode 100644 config/proxy/conformers/molecule_xtb.yaml rename config/proxy/conformers/{molecule_rf.yaml => rf.yaml} (61%) create mode 100644 config/proxy/conformers/torchani.yaml create mode 100644 config/proxy/conformers/xtb.yaml create mode 100644 gflownet/proxy/conformers/__init__.py create mode 100644 gflownet/proxy/conformers/rf.py rename gflownet/proxy/{molecule.py => conformers/torchani.py} (55%) create mode 100644 gflownet/proxy/conformers/xtb.py diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index 11f2f1737..df9219845 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -3,7 +3,7 @@ defaults: - override /env: conformers/conformer - override /gflownet: trajectorybalance - - override /proxy: conformers/molecule_xtb + - override /proxy: conformers/xtb - override /logger: wandb # Environment diff --git a/config/experiments/icml23/alaninedipeptide.yaml b/config/experiments/icml23/alaninedipeptide.yaml index 7d59ffea1..d19b5c611 100644 --- a/config/experiments/icml23/alaninedipeptide.yaml +++ b/config/experiments/icml23/alaninedipeptide.yaml @@ -3,7 +3,7 @@ defaults: - override /env: conformers/alaninedipeptide - override /gflownet: trajectorybalance - - override /proxy: conformers/molecule_rf + - override /proxy: conformers/rf - override /logger: wandb - override /user: sasha diff --git a/config/proxy/conformers/molecule_torchani.yaml b/config/proxy/conformers/molecule_torchani.yaml deleted file mode 100644 index e96f6e4f1..000000000 --- a/config/proxy/conformers/molecule_torchani.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: gflownet.proxy.molecule.TorchANIMoleculeEnergy - -model: ANI2x -use_ensemble: True \ No newline at end of file diff --git a/config/proxy/conformers/molecule_xtb.yaml b/config/proxy/conformers/molecule_xtb.yaml deleted file mode 100644 index e7f26328c..000000000 --- a/config/proxy/conformers/molecule_xtb.yaml +++ /dev/null @@ -1 +0,0 @@ -_target_: gflownet.proxy.molecule.XTBMoleculeEnergy diff --git a/config/proxy/conformers/molecule_rf.yaml b/config/proxy/conformers/rf.yaml similarity index 61% rename from config/proxy/conformers/molecule_rf.yaml rename to config/proxy/conformers/rf.yaml index 1283c09e7..78bcd098a 100644 --- a/config/proxy/conformers/molecule_rf.yaml +++ b/config/proxy/conformers/rf.yaml @@ -1,4 +1,4 @@ -_target_: gflownet.proxy.molecule.RFMoleculeEnergy +_target_: gflownet.proxy.conformers.rf.RFMoleculeEnergy path_to_model: './data/random_forest_reward_100.pkl' -url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF' \ No newline at end of file +url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF' diff --git a/config/proxy/conformers/torchani.yaml b/config/proxy/conformers/torchani.yaml new file mode 100644 index 000000000..dced8a348 --- /dev/null +++ b/config/proxy/conformers/torchani.yaml @@ -0,0 +1,4 @@ +_target_: gflownet.proxy.conformers.torchani.TorchANIMoleculeEnergy + +model: ANI2x +use_ensemble: True diff --git a/config/proxy/conformers/xtb.yaml b/config/proxy/conformers/xtb.yaml new file mode 100644 index 000000000..6d35712a0 --- /dev/null +++ b/config/proxy/conformers/xtb.yaml @@ -0,0 +1 @@ +_target_: gflownet.proxy.conformers.xtb.XTBMoleculeEnergy diff --git a/gflownet/proxy/conformers/__init__.py b/gflownet/proxy/conformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gflownet/proxy/conformers/rf.py b/gflownet/proxy/conformers/rf.py new file mode 100644 index 000000000..957dd9ab3 --- /dev/null +++ b/gflownet/proxy/conformers/rf.py @@ -0,0 +1,39 @@ +import pickle + +import numpy as np +import torch +from sklearn.ensemble import RandomForestRegressor + +from gflownet.proxy.base import Proxy +from gflownet.utils.common import download_file_if_not_exists + + +class RFMoleculeEnergy(Proxy): + def __init__(self, path_to_model, url_to_model, **kwargs): + super().__init__(**kwargs) + self.min = -np.log(105) + path_to_model = download_file_if_not_exists(path_to_model, url_to_model) + if path_to_model is not None: + with open(path_to_model, "rb") as inp: + self.model = pickle.load(inp) + + def set_n_dim(self, n_dim): + # self.n_dim is never used in this env, + # this is just to make molecule env work with htorus + self.n_dim = n_dim + + def __call__(self, states_proxy): + # output of the model is exp(-energy) / 100 + x = states_proxy % (2 * np.pi) + rewards = -np.log(self.model.predict(x) * 100) + return torch.tensor( + rewards, + dtype=self.float, + device=self.device, + ) + + def __deepcopy__(self, memo): + cls = self.__class__ + new_obj = cls.__new__(cls) + new_obj.__dict__.update(self.__dict__) + return new_obj diff --git a/gflownet/proxy/molecule.py b/gflownet/proxy/conformers/torchani.py similarity index 55% rename from gflownet/proxy/molecule.py rename to gflownet/proxy/conformers/torchani.py index 81a0c1909..de4dd41f8 100644 --- a/gflownet/proxy/molecule.py +++ b/gflownet/proxy/conformers/torchani.py @@ -1,17 +1,11 @@ -import pickle from typing import Iterable, List, Optional -import ray import numpy as np import torch import torchani -from sklearn.ensemble import RandomForestRegressor -from tblite.interface import Calculator, Structure from torch import FloatTensor, LongTensor, Tensor -from wurlitzer import pipes from gflownet.proxy.base import Proxy -from gflownet.utils.common import download_file_if_not_exists TORCHANI_MODELS = { "ANI1x": torchani.models.ANI1x, @@ -20,94 +14,6 @@ } -class RFMoleculeEnergy(Proxy): - def __init__(self, path_to_model, url_to_model, **kwargs): - super().__init__(**kwargs) - self.min = -np.log(105) - path_to_model = download_file_if_not_exists(path_to_model, url_to_model) - if path_to_model is not None: - with open(path_to_model, "rb") as inp: - self.model = pickle.load(inp) - - def set_n_dim(self, n_dim): - # self.n_dim is never used in this env, - # this is just to make molecule env work with htorus - self.n_dim = n_dim - - def __call__(self, states_proxy): - # output of the model is exp(-energy) / 100 - x = states_proxy % (2 * np.pi) - rewards = -np.log(self.model.predict(x) * 100) - return torch.tensor( - rewards, - dtype=self.float, - device=self.device, - ) - - def __deepcopy__(self, memo): - cls = self.__class__ - new_obj = cls.__new__(cls) - new_obj.__dict__.update(self.__dict__) - return new_obj - - -@ray.remote -def _get_energy(numbers, positions): - with pipes(): - calc = Calculator("GFN2-xTB", numbers, positions * 1.8897259886) - res = calc.singlepoint() - energy = res.get("energy").item() - - return energy - - -def _chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - -class XTBMoleculeEnergy(Proxy): - def __init__(self, batch_size=100, **kwargs): - super().__init__(**kwargs) - - self.batch_size = batch_size - self.min = -5 - self.max = 0 - self.conformer = None - - def setup(self, env=None): - self.conformer = env.conformer - - def _sync_conformer_with_state(self, state: List): - for idx, ta in enumerate(self.conformer.freely_rotatable_tas): - self.conformer.set_torsion_angle(ta, state[idx]) - return self.conformer - - def __call__(self, states: List) -> Tensor: - energies = [] - - for batch in _chunks(states, self.batch_size): - structures = [] - - for state in batch: - conf = self._sync_conformer_with_state(state) - structures.append( - (conf.get_atomic_numbers(), conf.get_atom_positions()) - ) - - tasks = [_get_energy.remote(s[0], s[1]) for s in structures] - energies.extend(ray.get(tasks)) - - return torch.tensor(energies, dtype=self.float, device=self.device) - - def __deepcopy__(self, memo): - cls = self.__class__ - new_obj = cls.__new__(cls) - new_obj.__dict__.update(self.__dict__) - return new_obj - - class TorchANIMoleculeEnergy(Proxy): def __init__( self, diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py new file mode 100644 index 000000000..a14dfac88 --- /dev/null +++ b/gflownet/proxy/conformers/xtb.py @@ -0,0 +1,66 @@ +from typing import List + +import ray +import torch +from tblite.interface import Calculator, Structure +from torch import Tensor +from wurlitzer import pipes + +from gflownet.proxy.base import Proxy + + +@ray.remote +def _get_energy(numbers, positions): + with pipes(): + calc = Calculator("GFN2-xTB", numbers, positions * 1.8897259886) + res = calc.singlepoint() + energy = res.get("energy").item() + + return energy + + +def _chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +class XTBMoleculeEnergy(Proxy): + def __init__(self, batch_size=100, **kwargs): + super().__init__(**kwargs) + + self.batch_size = batch_size + self.min = -5 + self.max = 0 + self.conformer = None + + def setup(self, env=None): + self.conformer = env.conformer + + def _sync_conformer_with_state(self, state: List): + for idx, ta in enumerate(self.conformer.freely_rotatable_tas): + self.conformer.set_torsion_angle(ta, state[idx]) + return self.conformer + + def __call__(self, states: List) -> Tensor: + energies = [] + + for batch in _chunks(states, self.batch_size): + structures = [] + + for state in batch: + conf = self._sync_conformer_with_state(state) + structures.append( + (conf.get_atomic_numbers(), conf.get_atom_positions()) + ) + + tasks = [_get_energy.remote(s[0], s[1]) for s in structures] + energies.extend(ray.get(tasks)) + + return torch.tensor(energies, dtype=self.float, device=self.device) + + def __deepcopy__(self, memo): + cls = self.__class__ + new_obj = cls.__new__(cls) + new_obj.__dict__.update(self.__dict__) + return new_obj From c4b97d8b6bb013b63fab12fd799b63345a20ff97 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 15:55:00 -0400 Subject: [PATCH 039/100] conformer's state2proxy returning 3D coordinates --- gflownet/envs/base.py | 20 +++++++ gflownet/envs/conformers/conformer.py | 85 +++++++++++++++------------ gflownet/gflownet.py | 6 +- gflownet/proxy/conformers/xtb.py | 25 +------- 4 files changed, 73 insertions(+), 63 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index be536e76b..dc886ef36 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -364,6 +364,26 @@ def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: """ return np.array(states) + def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]: + """ + Prepares a batch of states in "GFlowNet format" for the proxy. Typically, + this will be the same as the statebatch2proxy, but in cases in which proxy + input is already processed (e.g., conformers, with list of torsion angles + converted to 3D positions of atoms), this can be overwritten to preserve KDE. + """ + return self.statebatch2proxy(states) + + def statetorch2kde( + self, states: TensorType["batch_size", "state_dim"] + ) -> TensorType["batch_size", "state_proxy_dim"]: + """ + Prepares a batch of states in torch "GFlowNet format" for the KDE. Typically, + this will be the same as the statetorch2proxy, but in cases in which proxy + input is already processed (e.g., conformers, with list of torsion angles + converted to 3D positions of atoms), this can be overwritten to preserve KDE. + """ + return self.statetorch2proxy(states) + def policy2state(self, state_policy: List) -> List: """ Converts the model (e.g. one-hot encoding) version of a state given as diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 637be5e0d..233a010e0 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -1,4 +1,5 @@ -from typing import List, Tuple +import copy +from typing import List import numpy as np import numpy.typing as npt @@ -15,6 +16,7 @@ class Conformer(ContinuousTorus): but accepts any molecule (defined by SMILES, freely rotatable torsion angles, and path to dataset containing sample conformers. """ + def __init__( self, smiles: str, @@ -27,11 +29,14 @@ def __init__( path_to_dataset, url_to_dataset ) atom_positions = self.atom_positions_dataset.sample() - self.conformer = RDKitConformer( - atom_positions, smiles, torsion_angles - ) - n_dim = len(self.conformer.freely_rotatable_tas) - super().__init__(n_dim=n_dim, **kwargs) + self.conformer = RDKitConformer(atom_positions, smiles, torsion_angles) + + # Conversions + self.statebatch2oracle = self.statebatch2proxy + self.statetorch2oracle = self.statetorch2proxy + + super().__init__(n_dim=len(self.conformer.freely_rotatable_tas), **kwargs) + self.sync_conformer_with_state() def sync_conformer_with_state(self, state: List = None): @@ -41,39 +46,47 @@ def sync_conformer_with_state(self, state: List = None): self.conformer.set_torsion_angle(ta, state[idx]) return self.conformer - def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: + def statebatch2proxy(self, states: List[List]) -> List[npt.NDArray]: """ - Prepares a batch of states in torch "GFlowNet format" for the oracle. + Returns a list of proxy states, each being a numpy array with dimensionality + (n_atoms, 4), in which first the column encodes atomic number, and the last + three columns encode atom positions. """ - return states.cpu().numpy()[:, :-1] + states_proxy = [] + for st in states: + conf = self.sync_conformer_with_state(st) + states_proxy.append( + np.concatenate( + [ + conf.get_atomic_numbers()[..., np.newaxis], + conf.get_atom_positions(), + ], + axis=1, + ) + ) + return states_proxy - def statebatch2proxy(self, states: List[List]) -> npt.NDArray: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where - each state is a row of length n_dim with an angle in radians. The n_actions - item is removed. - """ + def statetorch2proxy( + self, states: TensorType["batch", "state_dim"] + ) -> List[npt.NDArray]: + return self.statebatch2proxy(states.cpu().numpy()) + + def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]: return np.array(states)[:, :-1] - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> List[Tuple[npt.NDArray, npt.NDArray]]: - """ - Prepares a batch of states in torch "GFlowNet format" for the oracle. - """ - return self.statebatch2oracle(states.cpu().numpy()) + def statetorch2kde( + self, states: TensorType["batch_size", "state_dim"] + ) -> TensorType["batch_size", "state_proxy_dim"]: + return states.cpu().numpy()[:, :-1] - def statebatch2oracle( - self, states: List[List] - ) -> List[Tuple[npt.NDArray, npt.NDArray]]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle: a list of - tuples, where first element in the tuple is numpy array of atom positions of - shape [num_atoms, 3] and the second element is numpy array of atomic numbers of - shape [num_atoms, ] - """ - states_oracle = [] - for st in states: - conf = self.sync_conformer_with_state(st) - states_oracle.append((conf.get_atom_positions(), conf.get_atomic_numbers())) - return states_oracle + def __deepcopy__(self, memo): + cls = self.__class__ + new_instance = cls.__new__(cls) + + for attr_name, attr_value in self.__dict__.items(): + if attr_name != "conformer": + setattr(new_instance, attr_name, copy.copy(attr_value)) + + new_instance.conformer = self.conformer + + return new_instance diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index a16fb8374..7353749c6 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -747,8 +747,8 @@ def test(self, **plot_kwargs): log_density_pred = np.log(density_pred + 1e-8) elif self.continuous: # TODO make it work with conditional env - x_sampled = torch2np(self.env.statebatch2proxy(x_sampled)) - x_tt = torch2np(self.env.statebatch2proxy(x_tt)) + x_sampled = torch2np(self.env.statebatch2kde(x_sampled)) + x_tt = torch2np(self.env.statebatch2kde(x_tt)) kde_pred = self.env.fit_kde( x_sampled, kernel=self.logger.test.kde.kernel, @@ -762,7 +762,7 @@ def test(self, **plot_kwargs): x_from_reward = self.env.sample_from_reward( n_samples=self.logger.test.n ) - x_from_reward = torch2np(self.env.statetorch2proxy(x_from_reward)) + x_from_reward = torch2np(self.env.statetorch2kde(x_from_reward)) # Fit KDE with samples from reward kde_true = self.env.fit_kde( x_from_reward, diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index a14dfac88..3f15ab7d3 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -32,35 +32,12 @@ def __init__(self, batch_size=100, **kwargs): self.batch_size = batch_size self.min = -5 self.max = 0 - self.conformer = None - - def setup(self, env=None): - self.conformer = env.conformer - - def _sync_conformer_with_state(self, state: List): - for idx, ta in enumerate(self.conformer.freely_rotatable_tas): - self.conformer.set_torsion_angle(ta, state[idx]) - return self.conformer def __call__(self, states: List) -> Tensor: energies = [] for batch in _chunks(states, self.batch_size): - structures = [] - - for state in batch: - conf = self._sync_conformer_with_state(state) - structures.append( - (conf.get_atomic_numbers(), conf.get_atom_positions()) - ) - - tasks = [_get_energy.remote(s[0], s[1]) for s in structures] + tasks = [_get_energy.remote(s[:, 0], s[:, 1:]) for s in batch] energies.extend(ray.get(tasks)) return torch.tensor(energies, dtype=self.float, device=self.device) - - def __deepcopy__(self, memo): - cls = self.__class__ - new_obj = cls.__new__(cls) - new_obj.__dict__.update(self.__dict__) - return new_obj From f6d693bcd2259267779d793208d042e42678d039 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 16:07:31 -0400 Subject: [PATCH 040/100] fixed typo --- gflownet/envs/conformers/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 233a010e0..58fe088a2 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -14,7 +14,7 @@ class Conformer(ContinuousTorus): """ Extension of continuous torus to conformer generation. Based on AlanineDipeptide, but accepts any molecule (defined by SMILES, freely rotatable torsion angles, and - path to dataset containing sample conformers. + path to dataset containing sample conformers). """ def __init__( From d5839b13fa889a3ea3b9e86089cf39b2f48ccd37 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 18:48:47 -0400 Subject: [PATCH 041/100] typo --- gflownet/envs/conformers/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 58fe088a2..e29fd20f4 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -49,7 +49,7 @@ def sync_conformer_with_state(self, state: List = None): def statebatch2proxy(self, states: List[List]) -> List[npt.NDArray]: """ Returns a list of proxy states, each being a numpy array with dimensionality - (n_atoms, 4), in which first the column encodes atomic number, and the last + (n_atoms, 4), in which the first column encodes atomic number, and the last three columns encode atom positions. """ states_proxy = [] From 20a8eacab38506d1bce6ac88bb2b99157333fdf9 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 18:56:43 -0400 Subject: [PATCH 042/100] atom_positions_dataset removed from attributes --- gflownet/envs/conformers/conformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index e29fd20f4..eef44cc57 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -25,10 +25,10 @@ def __init__( url_to_dataset: str, **kwargs, ): - self.atom_positions_dataset = AtomPositionsDataset( + atom_positions_dataset = AtomPositionsDataset( path_to_dataset, url_to_dataset ) - atom_positions = self.atom_positions_dataset.sample() + atom_positions = atom_positions_dataset.sample() self.conformer = RDKitConformer(atom_positions, smiles, torsion_angles) # Conversions From 9695baf2836329a43fc1e4b53fb0848e00a2bdf9 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 19:24:09 -0400 Subject: [PATCH 043/100] dataset file containing multiple molecules --- config/env/conformers/conformer.yaml | 13 +++---------- gflownet/envs/conformers/conformer.py | 6 +++--- gflownet/utils/molecule/datasets.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/config/env/conformers/conformer.yaml b/config/env/conformers/conformer.yaml index ee8afce89..29e97d101 100644 --- a/config/env/conformers/conformer.yaml +++ b/config/env/conformers/conformer.yaml @@ -3,17 +3,10 @@ defaults: _target_: gflownet.envs.conformers.conformer.Conformer -# alanine dipeptide -smiles: 'CC(C(=O)NC)NC(=O)C' -torsion_angles: [[0, 1, 2, 3], [0, 1, 6, 7]] -path_to_dataset: './data/alanine_dipeptide_conformers_1.npy' -url_to_dataset: 'https://drive.google.com/uc?id=1r1KRGcpBhR3xaS8yt2i64dfMnJGgNj4C' +smiles: 'CC(C(=O)NC)NC(=O)C' # alanine dipeptide -# ibuprofen -# smiles: 'CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O' -# torsion_angles: [[2, 1, 3, 4], [1, 3, 4, 5]] -# path_to_dataset: './data/ibuprofen_conformers_1.npy' -# url_to_dataset: 'https://drive.google.com/uc?id=1wRvaiQ0H2gP3gNqRfpwXRJJ4pF70ulyf' +path_to_dataset: './data/conformers.npy' +url_to_dataset: 'https://drive.google.com/uc?id=1MefikIedDjUtUJtzCHwXhZtpXHJX3ImA' id: conformer policy_encoding_dim_per_angle: null diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index eef44cc57..f2cc72d31 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -20,15 +20,15 @@ class Conformer(ContinuousTorus): def __init__( self, smiles: str, - torsion_angles: List[List[int]], path_to_dataset: str, url_to_dataset: str, **kwargs, ): atom_positions_dataset = AtomPositionsDataset( - path_to_dataset, url_to_dataset + smiles, path_to_dataset, url_to_dataset ) - atom_positions = atom_positions_dataset.sample() + atom_positions = atom_positions_dataset.first() + torsion_angles = atom_positions_dataset.torsion_angles self.conformer = RDKitConformer(atom_positions, smiles, torsion_angles) # Conversions diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py index f87fe7b22..e9d1595a2 100644 --- a/gflownet/utils/molecule/datasets.py +++ b/gflownet/utils/molecule/datasets.py @@ -7,9 +7,12 @@ class AtomPositionsDataset: - def __init__(self, path_to_data, url_to_data): + def __init__(self, smiles: str, path_to_data: str, url_to_data: str): path_to_data = download_file_if_not_exists(path_to_data, url_to_data) - self.positions = np.load(path_to_data) + conformers = np.load(path_to_data, allow_pickle=True).item() + + self.positions = conformers[smiles]['conformers'] + self.torsion_angles = conformers[smiles]['torsion_angles'] def __getitem__(self, i): return self.positions[i] @@ -21,6 +24,9 @@ def sample(self, size=None): idx = np.random.randint(0, len(self), size=size) return self.positions[idx] + def first(self): + return self[0] + class ConformersDataset: def __init__(self, path_to_data, url_to_data): From 6223cef611662079254e9f28adedfc5555ac6873 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 19:49:45 -0400 Subject: [PATCH 044/100] subtracting constant energy term --- config/env/conformers/conformer.yaml | 4 +++- gflownet/envs/conformers/conformer.py | 8 +++++++- gflownet/proxy/conformers/xtb.py | 14 ++++++++++---- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/config/env/conformers/conformer.yaml b/config/env/conformers/conformer.yaml index 29e97d101..b0e7f21ee 100644 --- a/config/env/conformers/conformer.yaml +++ b/config/env/conformers/conformer.yaml @@ -3,7 +3,9 @@ defaults: _target_: gflownet.envs.conformers.conformer.Conformer -smiles: 'CC(C(=O)NC)NC(=O)C' # alanine dipeptide +# smiles: 'CC(C(=O)NC)NC(=O)C' # alanine dipeptide +# smiles: 'CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O' # ibuprofen +smiles: 'O=C(c1ccc2n1CCC2C(=O)O)c3ccccc3' # ketorolac path_to_dataset: './data/conformers.npy' url_to_dataset: 'https://drive.google.com/uc?id=1MefikIedDjUtUJtzCHwXhZtpXHJX3ImA' diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index f2cc72d31..fca3c4649 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -3,9 +3,11 @@ import numpy as np import numpy.typing as npt +import ray from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus +from gflownet.proxy.conformers.xtb import get_energy from gflownet.utils.molecule.datasets import AtomPositionsDataset from gflownet.utils.molecule.rdkit_conformer import RDKitConformer @@ -31,6 +33,11 @@ def __init__( torsion_angles = atom_positions_dataset.torsion_angles self.conformer = RDKitConformer(atom_positions, smiles, torsion_angles) + tasks = [] + for positions in atom_positions_dataset.positions: + tasks.append(get_energy.remote(self.conformer.get_atomic_numbers(), positions)) + self.max_energy = max(ray.get(tasks)) + # Conversions self.statebatch2oracle = self.statebatch2proxy self.statetorch2oracle = self.statetorch2proxy @@ -38,7 +45,6 @@ def __init__( super().__init__(n_dim=len(self.conformer.freely_rotatable_tas), **kwargs) self.sync_conformer_with_state() - def sync_conformer_with_state(self, state: List = None): if state is None: state = self.state diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 3f15ab7d3..958252728 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -10,7 +10,7 @@ @ray.remote -def _get_energy(numbers, positions): +def get_energy(numbers, positions): with pipes(): calc = Calculator("GFN2-xTB", numbers, positions * 1.8897259886) res = calc.singlepoint() @@ -30,14 +30,20 @@ def __init__(self, batch_size=100, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size + self.max_energy = 0 self.min = -5 - self.max = 0 def __call__(self, states: List) -> Tensor: energies = [] for batch in _chunks(states, self.batch_size): - tasks = [_get_energy.remote(s[:, 0], s[:, 1:]) for s in batch] + tasks = [get_energy.remote(s[:, 0], s[:, 1:]) for s in batch] energies.extend(ray.get(tasks)) - return torch.tensor(energies, dtype=self.float, device=self.device) + energies = torch.tensor(energies, dtype=self.float, device=self.device) + energies -= self.max_energy + + return energies + + def setup(self, env=None): + self.max_energy = env.max_energy From fb404b1f3b4c928c15ec31bbd4a4e22dda87d925 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 15 Jun 2023 20:06:35 -0400 Subject: [PATCH 045/100] updated conformer conda env --- setup_conformer_conda.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup_conformer_conda.sh b/setup_conformer_conda.sh index 8c371213c..5e66b8f88 100644 --- a/setup_conformer_conda.sh +++ b/setup_conformer_conda.sh @@ -24,7 +24,7 @@ python -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spl # Install DGL (see https://www.dgl.ai/pages/start.html) python -m pip install dgl -f https://data.dgl.ai/wheels/cu117/repo.html # Requirements to run -python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d rdkit +python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d rdkit ray wurlitzer # Conditional requirements python -m pip install wandb matplotlib plotly pymatgen gdown # Dev packages From 15f905f253d1d9a877f770493d8cede55bed8ff9 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 16 Jun 2023 12:24:27 -0400 Subject: [PATCH 046/100] estimating min value for rejection sampling --- gflownet/envs/conformers/conformer.py | 4 +++- gflownet/proxy/conformers/xtb.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index fca3c4649..a670c6ef1 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -36,7 +36,9 @@ def __init__( tasks = [] for positions in atom_positions_dataset.positions: tasks.append(get_energy.remote(self.conformer.get_atomic_numbers(), positions)) - self.max_energy = max(ray.get(tasks)) + energies = ray.get(tasks) + self.max_energy = max(energies) + self.min_energy = min(energies) # Conversions self.statebatch2oracle = self.statebatch2proxy diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 958252728..2db2b55db 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -26,12 +26,12 @@ def _chunks(lst, n): class XTBMoleculeEnergy(Proxy): - def __init__(self, batch_size=100, **kwargs): + def __init__(self, batch_size=1000, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size self.max_energy = 0 - self.min = -5 + self.min = 0 def __call__(self, states: List) -> Tensor: energies = [] @@ -47,3 +47,4 @@ def __call__(self, states: List) -> Tensor: def setup(self, env=None): self.max_energy = env.max_energy + self.min = env.min_energy - env.max_energy From f6166c92b3866c76a6844c4e68524ebb8011766b Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 16 Jun 2023 12:25:46 -0400 Subject: [PATCH 047/100] black --- gflownet/envs/conformers/conformer.py | 5 ++++- gflownet/utils/molecule/datasets.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index a670c6ef1..b2c5a371c 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -35,7 +35,9 @@ def __init__( tasks = [] for positions in atom_positions_dataset.positions: - tasks.append(get_energy.remote(self.conformer.get_atomic_numbers(), positions)) + tasks.append( + get_energy.remote(self.conformer.get_atomic_numbers(), positions) + ) energies = ray.get(tasks) self.max_energy = max(energies) self.min_energy = min(energies) @@ -47,6 +49,7 @@ def __init__( super().__init__(n_dim=len(self.conformer.freely_rotatable_tas), **kwargs) self.sync_conformer_with_state() + def sync_conformer_with_state(self, state: List = None): if state is None: state = self.state diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py index e9d1595a2..feb49a5cb 100644 --- a/gflownet/utils/molecule/datasets.py +++ b/gflownet/utils/molecule/datasets.py @@ -11,8 +11,8 @@ def __init__(self, smiles: str, path_to_data: str, url_to_data: str): path_to_data = download_file_if_not_exists(path_to_data, url_to_data) conformers = np.load(path_to_data, allow_pickle=True).item() - self.positions = conformers[smiles]['conformers'] - self.torsion_angles = conformers[smiles]['torsion_angles'] + self.positions = conformers[smiles]["conformers"] + self.torsion_angles = conformers[smiles]["torsion_angles"] def __getitem__(self, i): return self.positions[i] From 1c67145e3ba4524e6d99fa1c425b02893a1e265a Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 16 Jun 2023 19:03:13 -0400 Subject: [PATCH 048/100] ray setup for cluster --- main.py | 3 +++ setup_conformer_conda.sh | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index d9d30296d..ed1d51115 100644 --- a/main.py +++ b/main.py @@ -16,12 +16,15 @@ import hydra import pandas as pd +import ray import yaml from omegaconf import DictConfig, OmegaConf @hydra.main(config_path="./config", config_name="main", version_base="1.1") def main(config): + ray.init(num_cpus=12) + # Get current directory and set it as root log dir for Logger cwd = os.getcwd() config.logger.logdir.root = cwd diff --git a/setup_conformer_conda.sh b/setup_conformer_conda.sh index 5e66b8f88..d4a40cb27 100644 --- a/setup_conformer_conda.sh +++ b/setup_conformer_conda.sh @@ -24,7 +24,8 @@ python -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spl # Install DGL (see https://www.dgl.ai/pages/start.html) python -m pip install dgl -f https://data.dgl.ai/wheels/cu117/repo.html # Requirements to run -python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d rdkit ray wurlitzer +python -m pip install ray ray[tune] ray[default] +python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d rdkit wurlitzer # Conditional requirements python -m pip install wandb matplotlib plotly pymatgen gdown # Dev packages From 7a6fa8e9e7742bd69d83fd6f923105627c25c591 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 21 Jun 2023 20:37:15 -0400 Subject: [PATCH 049/100] updated torchani proxy input format --- gflownet/proxy/conformers/torchani.py | 37 ++++++++----------- .../test_torchani.py} | 10 +++-- 2 files changed, 22 insertions(+), 25 deletions(-) rename tests/gflownet/proxy/{test_molecule.py => conformers/test_torchani.py} (64%) diff --git a/gflownet/proxy/conformers/torchani.py b/gflownet/proxy/conformers/torchani.py index de4dd41f8..cfffdf45a 100644 --- a/gflownet/proxy/conformers/torchani.py +++ b/gflownet/proxy/conformers/torchani.py @@ -1,9 +1,8 @@ -from typing import Iterable, List, Optional +from typing import Iterable, Optional -import numpy as np import torch import torchani -from torch import FloatTensor, LongTensor, Tensor +from torch import Tensor from gflownet.proxy.base import Proxy @@ -54,15 +53,6 @@ def __init__( self.model = TORCHANI_MODELS[model]( periodic_table_index=True, model_index=None if use_ensemble else 0 ).to(self.device) - self.conformer = None - - def setup(self, env=None): - self.conformer = env.conformer - - def _sync_conformer_with_state(self, state: List): - for idx, ta in enumerate(self.conformer.freely_rotatable_tas): - self.conformer.set_torsion_angle(ta, state[idx]) - return self.conformer @torch.no_grad() def __call__(self, states: Iterable) -> Tensor: @@ -70,7 +60,9 @@ def __call__(self, states: Iterable) -> Tensor: Args ---- states - An iterable of states in AlanineDipeptide environment format (torsion angles). + An iterable of states in Conformer environment format (tensors with + dimensionality (n_atoms, 4), in which the first column encodes atomic + number, and the last three columns encode atom positions). Returns ---- @@ -82,13 +74,17 @@ def __call__(self, states: Iterable) -> Tensor: coordinates = [] for st in states: - conf = self._sync_conformer_with_state(st) - - elements.append(conf.get_atomic_numbers()) - coordinates.append(conf.get_atom_positions()) - - elements = LongTensor(np.array(elements)).to(self.device) - coordinates = FloatTensor(np.array(coordinates)).to(self.device) + el = st[:, 0] + if not isinstance(el, Tensor): + el = Tensor(el) + co = st[:, 1:] + if not isinstance(co, Tensor): + co = Tensor(co) + elements.append(el) + coordinates.append(co) + + elements = torch.stack(elements).long().to(self.device) + coordinates = torch.stack(coordinates).float().to(self.device) if self.batch_size is not None: energies = [] @@ -111,5 +107,4 @@ def __deepcopy__(self, memo): new_obj.batch_size = self.batch_size new_obj.min = self.min new_obj.model = self.model - new_obj.conformer = self.conformer return new_obj diff --git a/tests/gflownet/proxy/test_molecule.py b/tests/gflownet/proxy/conformers/test_torchani.py similarity index 64% rename from tests/gflownet/proxy/test_molecule.py rename to tests/gflownet/proxy/conformers/test_torchani.py index 07f7a70c3..aceb37172 100644 --- a/tests/gflownet/proxy/test_molecule.py +++ b/tests/gflownet/proxy/conformers/test_torchani.py @@ -2,7 +2,7 @@ import pytest import torch -from gflownet.proxy.molecule import TorchANIMoleculeEnergy +from gflownet.proxy.conformers.torchani import TorchANIMoleculeEnergy from gflownet.utils.molecule.rdkit_conformer import get_dummy_ad_rdkconf @@ -14,8 +14,9 @@ def proxy(): def test__torchani_molecule_energy__predicts_energy_for_a_single_numpy_conformer(proxy): conf = get_dummy_ad_rdkconf() coordinates, elements = conf.get_atom_positions(), conf.get_atomic_numbers() + state = np.concatenate((np.expand_dims(elements, axis=1), coordinates), axis=1) - proxy(elements[np.newaxis, ...], coordinates[np.newaxis, ...]) + assert proxy(state[np.newaxis, ...]).shape == torch.Size([1]) def test__torchani_molecule_energy__predicts_energy_for_a_pytorch_batch(proxy): @@ -23,6 +24,7 @@ def test__torchani_molecule_energy__predicts_energy_for_a_pytorch_batch(proxy): coordinates, elements = conf.get_atom_positions(), conf.get_atomic_numbers() coordinates = torch.Tensor(coordinates).repeat(3, 1, 1) - elements = torch.Tensor(elements).repeat(3, 1) + elements = torch.Tensor(elements).repeat(3, 1).unsqueeze(-1) + state = torch.concat((elements, coordinates), dim=-1) - proxy(elements, coordinates) + assert proxy(state).shape == torch.Size([3]) From bea25acaf0aeedae1f8a29057bd150a95f68d9fe Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 22 Jun 2023 10:02:53 -0400 Subject: [PATCH 050/100] number of jobs in config --- config/experiments/conformer.yaml | 3 +++ config/main.yaml | 2 ++ main.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index df9219845..b4c354e35 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -6,6 +6,9 @@ defaults: - override /proxy: conformers/xtb - override /logger: wandb +# Number of parallel ray jobs (XTB greatly benefits from parallelization) +n_jobs: 12 + # Environment env: length_traj: 10 diff --git a/config/main.yaml b/config/main.yaml index 97417b444..e06b57762 100644 --- a/config/main.yaml +++ b/config/main.yaml @@ -14,6 +14,8 @@ float_precision: 32 n_samples: 1000 # Random seeds seed: 0 +# Number of parallel ray jobs +n_jobs: 1 # Hydra config hydra: diff --git a/main.py b/main.py index ed1d51115..36e7b59ed 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,7 @@ @hydra.main(config_path="./config", config_name="main", version_base="1.1") def main(config): - ray.init(num_cpus=12) + ray.init(num_cpus=config.n_jobs) # Get current directory and set it as root log dir for Logger cwd = os.getcwd() From a1b89dcf6790e15b24b51b6bc69cd3fd3c8dd5c4 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 22 Jun 2023 10:39:24 -0400 Subject: [PATCH 051/100] policy class moved to a separate file --- gflownet/gflownet.py | 134 +--------------------------------- gflownet/policy/__init__.py | 0 gflownet/policy/base.py | 141 ++++++++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 133 deletions(-) create mode 100644 gflownet/policy/__init__.py create mode 100644 gflownet/policy/base.py diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 7353749c6..f416063fd 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -16,11 +16,11 @@ import torch import torch.nn as nn import yaml -from omegaconf import OmegaConf from scipy.special import logsumexp from torch.distributions import Bernoulli, Categorical from tqdm import tqdm +from gflownet.policy.base import Policy from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer from gflownet.utils.common import ( @@ -879,138 +879,6 @@ def log_iter( ) -class Policy: - def __init__(self, config, env, device, float_precision, base=None): - # If config is null, default to uniform - if config is None: - config = OmegaConf.create() - config.type = "uniform" - # Device and float precision - self.device = device - self.float = float_precision - # Input and output dimensions - self.state_dim = env.policy_input_dim - self.fixed_output = torch.tensor(env.fixed_policy_output).to( - dtype=self.float, device=self.device - ) - self.random_output = torch.tensor(env.random_policy_output).to( - dtype=self.float, device=self.device - ) - self.output_dim = len(self.fixed_output) - if "shared_weights" in config: - self.shared_weights = config.shared_weights - else: - self.shared_weights = False - self.base = base - if "n_hid" in config: - self.n_hid = config.n_hid - else: - self.n_hid = None - if "n_layers" in config: - self.n_layers = config.n_layers - else: - self.n_layers = None - if "tail" in config: - self.tail = config.tail - else: - self.tail = [] - if "type" in config: - self.type = config.type - elif self.shared_weights: - self.type = self.base.type - else: - raise "Policy type must be defined if shared_weights is False" - # Instantiate policy - if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False - elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False - elif self.type == "mlp": - self.model = self.make_mlp(nn.LeakyReLU()) - self.is_model = True - else: - raise "Policy model type not defined" - if self.is_model: - self.model.to(self.device) - - def __call__(self, states): - return self.model(states) - - def make_mlp(self, activation): - """ - Defines an MLP with no top layer activation - If share_weight == True, - baseModel (the model with which weights are to be shared) must be provided - Args - ---- - layers_dim : list - Dimensionality of each layer - activation : Activation - Activation function - """ - if self.shared_weights == True and self.base is not None: - mlp = nn.Sequential( - self.base.model[:-1], - nn.Linear( - self.base.model[-1].in_features, self.base.model[-1].out_features - ), - ) - return mlp - elif self.shared_weights == False: - layers_dim = ( - [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] - ) - mlp = nn.Sequential( - *( - sum( - [ - [nn.Linear(idim, odim)] - + ([activation] if n < len(layers_dim) - 2 else []) - for n, (idim, odim) in enumerate( - zip(layers_dim, layers_dim[1:]) - ) - ], - [], - ) - + self.tail - ) - ) - return mlp - else: - raise ValueError( - "Base Model must be provided when shared_weights is set to True" - ) - - def fixed_distribution(self, states): - """ - Returns the fixed distribution specified by the environment. - Args: states: tensor - """ - return torch.tile(self.fixed_output, (len(states), 1)).to( - dtype=self.float, device=self.device - ) - - def random_distribution(self, states): - """ - Returns the random distribution specified by the environment. - Args: states: tensor - """ - return torch.tile(self.random_output, (len(states), 1)).to( - dtype=self.float, device=self.device - ) - - def uniform_distribution(self, states): - """ - Return action logits (log probabilities) from a uniform distribution - Args: states: tensor - """ - return torch.ones( - (len(states), self.output_dim), dtype=self.float, device=self.device - ) - - def make_opt(params, logZ, config): """ Set up the optimizer diff --git a/gflownet/policy/__init__.py b/gflownet/policy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py new file mode 100644 index 000000000..17b330ea5 --- /dev/null +++ b/gflownet/policy/base.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +from omegaconf import OmegaConf + + +class Policy: + def __init__(self, config, env, device, float_precision, base=None): + # If config is null, default to uniform + if config is None: + config = OmegaConf.create() + config.type = "uniform" + # Device and float precision + self.device = device + self.float = float_precision + # Input and output dimensions + self.state_dim = env.policy_input_dim + self.fixed_output = torch.tensor(env.fixed_policy_output).to( + dtype=self.float, device=self.device + ) + self.random_output = torch.tensor(env.random_policy_output).to( + dtype=self.float, device=self.device + ) + self.output_dim = len(self.fixed_output) + self.base = base + + self._setup_config(config) + self._instantiate_policy() + + def _setup_config(self, config): + if "shared_weights" in config: + self.shared_weights = config.shared_weights + else: + self.shared_weights = False + if "n_hid" in config: + self.n_hid = config.n_hid + else: + self.n_hid = None + if "n_layers" in config: + self.n_layers = config.n_layers + else: + self.n_layers = None + if "tail" in config: + self.tail = config.tail + else: + self.tail = [] + if "type" in config: + self.type = config.type + elif self.shared_weights: + self.type = self.base.type + else: + raise "Policy type must be defined if shared_weights is False" + + def _instantiate_policy(self): + if self.type == "fixed": + self.model = self.fixed_distribution + self.is_model = False + elif self.type == "uniform": + self.model = self.uniform_distribution + self.is_model = False + elif self.type == "mlp": + self.model = self.make_mlp(nn.LeakyReLU()) + self.is_model = True + else: + raise "Policy model type not defined" + if self.is_model: + self.model.to(self.device) + + def __call__(self, states): + return self.model(states) + + def make_mlp(self, activation): + """ + Defines an MLP with no top layer activation + If share_weight == True, + baseModel (the model with which weights are to be shared) must be provided + Args + ---- + layers_dim : list + Dimensionality of each layer + activation : Activation + Activation function + """ + if self.shared_weights == True and self.base is not None: + mlp = nn.Sequential( + self.base.model[:-1], + nn.Linear( + self.base.model[-1].in_features, self.base.model[-1].out_features + ), + ) + return mlp + elif self.shared_weights == False: + layers_dim = ( + [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] + ) + mlp = nn.Sequential( + *( + sum( + [ + [nn.Linear(idim, odim)] + + ([activation] if n < len(layers_dim) - 2 else []) + for n, (idim, odim) in enumerate( + zip(layers_dim, layers_dim[1:]) + ) + ], + [], + ) + + self.tail + ) + ) + return mlp + else: + raise ValueError( + "Base Model must be provided when shared_weights is set to True" + ) + + def fixed_distribution(self, states): + """ + Returns the fixed distribution specified by the environment. + Args: states: tensor + """ + return torch.tile(self.fixed_output, (len(states), 1)).to( + dtype=self.float, device=self.device + ) + + def random_distribution(self, states): + """ + Returns the random distribution specified by the environment. + Args: states: tensor + """ + return torch.tile(self.random_output, (len(states), 1)).to( + dtype=self.float, device=self.device + ) + + def uniform_distribution(self, states): + """ + Return action logits (log probabilities) from a uniform distribution + Args: states: tensor + """ + return torch.ones( + (len(states), self.output_dim), dtype=self.float, device=self.device + ) From db1a9da58d0ea8d0368750860f7e3698f065d430 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 22 Jun 2023 10:51:16 -0400 Subject: [PATCH 052/100] disabled ray logging --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 36e7b59ed..d4175868c 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,7 @@ @hydra.main(config_path="./config", config_name="main", version_base="1.1") def main(config): - ray.init(num_cpus=config.n_jobs) + ray.init(num_cpus=config.n_jobs, log_to_driver=False) # Get current directory and set it as root log dir for Logger cwd = os.getcwd() From bd820ef5b8cd32fc253ed57fbcacd9acbe0a31dd Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 22 Jun 2023 12:53:06 -0400 Subject: [PATCH 053/100] updated conformer env to return numpy arrays in *2proxy methods --- gflownet/envs/conformers/conformer.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index b2c5a371c..9f0e36d84 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -57,7 +57,7 @@ def sync_conformer_with_state(self, state: List = None): self.conformer.set_torsion_angle(ta, state[idx]) return self.conformer - def statebatch2proxy(self, states: List[List]) -> List[npt.NDArray]: + def statebatch2proxy(self, states: List[List]) -> npt.NDArray: """ Returns a list of proxy states, each being a numpy array with dimensionality (n_atoms, 4), in which the first column encodes atomic number, and the last @@ -75,11 +75,9 @@ def statebatch2proxy(self, states: List[List]) -> List[npt.NDArray]: axis=1, ) ) - return states_proxy + return np.array(states_proxy) - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> List[npt.NDArray]: + def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: return self.statebatch2proxy(states.cpu().numpy()) def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]: From 4bfb86b1ad1c6d504dffdc8e966093b856eb16d7 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 10:12:06 -0400 Subject: [PATCH 054/100] XTB proxy renamed to TBLite --- config/experiments/conformer.yaml | 2 +- config/proxy/conformers/tblite.yaml | 1 + config/proxy/conformers/xtb.yaml | 1 - gflownet/envs/conformers/conformer.py | 2 +- gflownet/proxy/conformers/{xtb.py => tblite.py} | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 config/proxy/conformers/tblite.yaml delete mode 100644 config/proxy/conformers/xtb.yaml rename gflownet/proxy/conformers/{xtb.py => tblite.py} (97%) diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index b4c354e35..a3dc89c45 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -3,7 +3,7 @@ defaults: - override /env: conformers/conformer - override /gflownet: trajectorybalance - - override /proxy: conformers/xtb + - override /proxy: conformers/tblite - override /logger: wandb # Number of parallel ray jobs (XTB greatly benefits from parallelization) diff --git a/config/proxy/conformers/tblite.yaml b/config/proxy/conformers/tblite.yaml new file mode 100644 index 000000000..09e4119b0 --- /dev/null +++ b/config/proxy/conformers/tblite.yaml @@ -0,0 +1 @@ +_target_: gflownet.proxy.conformers.tblite.TBLiteMoleculeEnergy diff --git a/config/proxy/conformers/xtb.yaml b/config/proxy/conformers/xtb.yaml deleted file mode 100644 index 6d35712a0..000000000 --- a/config/proxy/conformers/xtb.yaml +++ /dev/null @@ -1 +0,0 @@ -_target_: gflownet.proxy.conformers.xtb.XTBMoleculeEnergy diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 9f0e36d84..a7e2aa4a8 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -7,7 +7,7 @@ from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus -from gflownet.proxy.conformers.xtb import get_energy +from gflownet.proxy.conformers.tblite import get_energy from gflownet.utils.molecule.datasets import AtomPositionsDataset from gflownet.utils.molecule.rdkit_conformer import RDKitConformer diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/tblite.py similarity index 97% rename from gflownet/proxy/conformers/xtb.py rename to gflownet/proxy/conformers/tblite.py index 2db2b55db..956b1d737 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/tblite.py @@ -25,7 +25,7 @@ def _chunks(lst, n): yield lst[i : i + n] -class XTBMoleculeEnergy(Proxy): +class TBLiteMoleculeEnergy(Proxy): def __init__(self, batch_size=1000, **kwargs): super().__init__(**kwargs) From 3d3c80962bc55facc8229ed9ad718c5e30387f79 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 11:35:06 -0400 Subject: [PATCH 055/100] re-added XTB proxy --- config/proxy/conformers/xtb.yaml | 1 + gflownet/proxy/conformers/xtb.py | 76 ++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 config/proxy/conformers/xtb.yaml create mode 100644 gflownet/proxy/conformers/xtb.py diff --git a/config/proxy/conformers/xtb.yaml b/config/proxy/conformers/xtb.yaml new file mode 100644 index 000000000..6d35712a0 --- /dev/null +++ b/config/proxy/conformers/xtb.yaml @@ -0,0 +1 @@ +_target_: gflownet.proxy.conformers.xtb.XTBMoleculeEnergy diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py new file mode 100644 index 000000000..767a628f2 --- /dev/null +++ b/gflownet/proxy/conformers/xtb.py @@ -0,0 +1,76 @@ +# This is a hotfix for tblite (used for the conformer generation) not +# importing correctly unless it is being imported first. +try: + from tblite import interface +except: + pass + +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Iterable + +import numpy as np +import numpy.typing as npt +import ray +import torch +from torch import Tensor +from wurlitzer import pipes + +from gflownet.proxy.base import Proxy +from gflownet.proxy.conformers.tblite import _chunks +from gflownet.utils.molecule.xtb import run_gfn_xtb + + +def _write_xyz_file(elements: npt.NDArray, coordinates: npt.NDArray, file_path: str) -> None: + num_atoms = len(elements) + with open(file_path, 'w') as f: + f.write(str(num_atoms) + '\n') + f.write('\n') + + for i in range(num_atoms): + element = elements[i] + x, y, z = coordinates[i] + line = f"{int(element)} {x:.6f} {y:.6f} {z:.6f}\n" + f.write(line) + + +@ray.remote +def get_energy(numbers, positions, method="gfnff"): + directory = TemporaryDirectory() + file_name = "input.xyz" + + _write_xyz_file(numbers, positions, str(Path(directory.name) / "input.xyz")) + with pipes(): + energy = run_gfn_xtb(directory.name, file_name, gfn_version=method) + directory.cleanup() + + if np.isnan(energy): + return 0.0 + + return energy + + +class XTBMoleculeEnergy(Proxy): + def __init__(self, method: str = "gfnff", batch_size=1000, **kwargs): + super().__init__(**kwargs) + + self.method = method + self.batch_size = batch_size + self.max_energy = 0 + self.min = 0 + + def __call__(self, states: Iterable) -> Tensor: + energies = [] + + for batch in _chunks(states, self.batch_size): + tasks = [get_energy.remote(s[:, 0], s[:, 1:], self.method) for s in batch] + energies.extend(ray.get(tasks)) + + energies = torch.tensor(energies, dtype=self.float, device=self.device) + energies -= self.max_energy + + return energies + + def setup(self, env=None): + self.max_energy = env.max_energy + self.min = env.min_energy - env.max_energy From 16c83af3bfe521d60b16e42a68d613362fb21931 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 13:19:10 -0400 Subject: [PATCH 056/100] constant energy term computation in a proxy instead of environment --- gflownet/envs/conformers/conformer.py | 11 ----------- gflownet/proxy/conformers/tblite.py | 13 +++++++++---- gflownet/proxy/conformers/xtb.py | 11 ++++++++--- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index a7e2aa4a8..0101cc734 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -3,11 +3,9 @@ import numpy as np import numpy.typing as npt -import ray from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus -from gflownet.proxy.conformers.tblite import get_energy from gflownet.utils.molecule.datasets import AtomPositionsDataset from gflownet.utils.molecule.rdkit_conformer import RDKitConformer @@ -33,15 +31,6 @@ def __init__( torsion_angles = atom_positions_dataset.torsion_angles self.conformer = RDKitConformer(atom_positions, smiles, torsion_angles) - tasks = [] - for positions in atom_positions_dataset.positions: - tasks.append( - get_energy.remote(self.conformer.get_atomic_numbers(), positions) - ) - energies = ray.get(tasks) - self.max_energy = max(energies) - self.min_energy = min(energies) - # Conversions self.statebatch2oracle = self.statebatch2proxy self.statetorch2oracle = self.statetorch2proxy diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index 956b1d737..496b4f1bf 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -1,8 +1,9 @@ from typing import List +import numpy as np import ray import torch -from tblite.interface import Calculator, Structure +from tblite.interface import Calculator from torch import Tensor from wurlitzer import pipes @@ -26,10 +27,11 @@ def _chunks(lst, n): class TBLiteMoleculeEnergy(Proxy): - def __init__(self, batch_size=1000, **kwargs): + def __init__(self, batch_size=1024, n_samples=5000, **kwargs): super().__init__(**kwargs) self.batch_size = batch_size + self.n_samples = n_samples self.max_energy = 0 self.min = 0 @@ -46,5 +48,8 @@ def __call__(self, states: List) -> Tensor: return energies def setup(self, env=None): - self.max_energy = env.max_energy - self.min = env.min_energy - env.max_energy + states = env.statebatch2proxy(2 * np.pi * np.random.rand(self.n_samples, 3)) + energies = self(states) + + self.max_energy = max(energies) + self.min = min(energies) - self.max_energy diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 767a628f2..de42c1b8b 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -51,11 +51,12 @@ def get_energy(numbers, positions, method="gfnff"): class XTBMoleculeEnergy(Proxy): - def __init__(self, method: str = "gfnff", batch_size=1000, **kwargs): + def __init__(self, method: str = "gfnff", batch_size=1024, n_samples=5000, **kwargs): super().__init__(**kwargs) self.method = method self.batch_size = batch_size + self.n_samples = n_samples self.max_energy = 0 self.min = 0 @@ -65,6 +66,7 @@ def __call__(self, states: Iterable) -> Tensor: for batch in _chunks(states, self.batch_size): tasks = [get_energy.remote(s[:, 0], s[:, 1:], self.method) for s in batch] energies.extend(ray.get(tasks)) + print(len(energies)) energies = torch.tensor(energies, dtype=self.float, device=self.device) energies -= self.max_energy @@ -72,5 +74,8 @@ def __call__(self, states: Iterable) -> Tensor: return energies def setup(self, env=None): - self.max_energy = env.max_energy - self.min = env.min_energy - env.max_energy + states = env.statebatch2proxy(2 * np.pi * np.random.rand(self.n_samples, 3)) + energies = self(states) + + self.max_energy = max(energies) + self.min = min(energies) - self.max_energy From 4de0a0db5a354b04b1a75547e8abd4f603336831 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 13:23:54 -0400 Subject: [PATCH 057/100] removed print --- gflownet/proxy/conformers/xtb.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index de42c1b8b..1bf069f5e 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -66,7 +66,6 @@ def __call__(self, states: Iterable) -> Tensor: for batch in _chunks(states, self.batch_size): tasks = [get_energy.remote(s[:, 0], s[:, 1:], self.method) for s in batch] energies.extend(ray.get(tasks)) - print(len(energies)) energies = torch.tensor(energies, dtype=self.float, device=self.device) energies -= self.max_energy From 75a515253c702bcddffcc40ddc067020a01378b5 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 13:38:07 -0400 Subject: [PATCH 058/100] added constant subtraction in base conformer proxy class --- gflownet/proxy/conformers/base.py | 38 +++++++++++++++++++++++++++ gflownet/proxy/conformers/tblite.py | 21 +++------------ gflownet/proxy/conformers/torchani.py | 24 +++++++---------- gflownet/proxy/conformers/xtb.py | 31 +++++++++------------- 4 files changed, 64 insertions(+), 50 deletions(-) create mode 100644 gflownet/proxy/conformers/base.py diff --git a/gflownet/proxy/conformers/base.py b/gflownet/proxy/conformers/base.py new file mode 100644 index 000000000..fcee38400 --- /dev/null +++ b/gflownet/proxy/conformers/base.py @@ -0,0 +1,38 @@ +from abc import ABC +from typing import Optional + +import numpy as np + +from gflownet.proxy.base import Proxy + + +class MoleculeEnergyBase(Proxy, ABC): + def __init__( + self, + batch_size: Optional[int] = 128, + n_samples: int = 5000, + **kwargs, + ): + """ + Parameters + ---------- + + batch_size : int + Batch size for the underlying model. + + n_samples : int + Number of samples that will be used to estimate minimum and maximum energy. + """ + super().__init__(**kwargs) + + self.batch_size = batch_size + self.n_samples = n_samples + self.max_energy = 0 + self.min = 0 + + def setup(self, env=None): + states = env.statebatch2proxy(2 * np.pi * np.random.rand(self.n_samples, 3)) + energies = self(states) + + self.max_energy = max(energies) + self.min = min(energies) - self.max_energy diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index 496b4f1bf..93bae15ca 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -1,13 +1,12 @@ from typing import List -import numpy as np import ray import torch from tblite.interface import Calculator from torch import Tensor from wurlitzer import pipes -from gflownet.proxy.base import Proxy +from gflownet.proxy.conformers.base import MoleculeEnergyBase @ray.remote @@ -26,14 +25,9 @@ def _chunks(lst, n): yield lst[i : i + n] -class TBLiteMoleculeEnergy(Proxy): - def __init__(self, batch_size=1024, n_samples=5000, **kwargs): - super().__init__(**kwargs) - - self.batch_size = batch_size - self.n_samples = n_samples - self.max_energy = 0 - self.min = 0 +class TBLiteMoleculeEnergy(MoleculeEnergyBase): + def __init__(self, batch_size: int = 1024, n_samples: int = 5000, **kwargs): + super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) def __call__(self, states: List) -> Tensor: energies = [] @@ -46,10 +40,3 @@ def __call__(self, states: List) -> Tensor: energies -= self.max_energy return energies - - def setup(self, env=None): - states = env.statebatch2proxy(2 * np.pi * np.random.rand(self.n_samples, 3)) - energies = self(states) - - self.max_energy = max(energies) - self.min = min(energies) - self.max_energy diff --git a/gflownet/proxy/conformers/torchani.py b/gflownet/proxy/conformers/torchani.py index cfffdf45a..cef2a308a 100644 --- a/gflownet/proxy/conformers/torchani.py +++ b/gflownet/proxy/conformers/torchani.py @@ -4,7 +4,7 @@ import torchani from torch import Tensor -from gflownet.proxy.base import Proxy +from gflownet.proxy.conformers.base import MoleculeEnergyBase TORCHANI_MODELS = { "ANI1x": torchani.models.ANI1x, @@ -13,13 +13,13 @@ } -class TorchANIMoleculeEnergy(Proxy): +class TorchANIMoleculeEnergy(MoleculeEnergyBase): def __init__( self, model: str = "ANI2x", use_ensemble: bool = True, - batch_size: Optional[int] = None, - divider: float = 100.0, + batch_size: Optional[int] = 128, + n_samples: int = 5000, **kwargs, ): """ @@ -33,16 +33,8 @@ def __init__( batch_size : int Batch size for TorchANI. If none, will process all states as a single batch. - - divider : float - The value by which the output of TorchANI will be divided. Necessary for Boltzmann - reward function with high betas, for which the values can explode without division. """ - super().__init__(**kwargs) - - self.batch_size = batch_size - self.divider = divider - self.min = -5 + super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) if TORCHANI_MODELS.get(model) is None: raise ValueError( @@ -99,12 +91,16 @@ def __call__(self, states: Iterable) -> Tensor: else: energies = self.model((elements, coordinates)).energies.float() - return energies / self.divider + energies -= self.max_energy + + return energies def __deepcopy__(self, memo): cls = self.__class__ new_obj = cls.__new__(cls) new_obj.batch_size = self.batch_size + new_obj.n_samples = self.n_samples + new_obj.max_energy = self.max_energy new_obj.min = self.min new_obj.model = self.model return new_obj diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 1bf069f5e..74fbbcdf3 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -16,16 +16,18 @@ from torch import Tensor from wurlitzer import pipes -from gflownet.proxy.base import Proxy +from gflownet.proxy.conformers.base import MoleculeEnergyBase from gflownet.proxy.conformers.tblite import _chunks from gflownet.utils.molecule.xtb import run_gfn_xtb -def _write_xyz_file(elements: npt.NDArray, coordinates: npt.NDArray, file_path: str) -> None: +def _write_xyz_file( + elements: npt.NDArray, coordinates: npt.NDArray, file_path: str +) -> None: num_atoms = len(elements) - with open(file_path, 'w') as f: - f.write(str(num_atoms) + '\n') - f.write('\n') + with open(file_path, "w") as f: + f.write(str(num_atoms) + "\n") + f.write("\n") for i in range(num_atoms): element = elements[i] @@ -50,15 +52,13 @@ def get_energy(numbers, positions, method="gfnff"): return energy -class XTBMoleculeEnergy(Proxy): - def __init__(self, method: str = "gfnff", batch_size=1024, n_samples=5000, **kwargs): - super().__init__(**kwargs) +class XTBMoleculeEnergy(MoleculeEnergyBase): + def __init__( + self, method: str = "gfnff", batch_size=1024, n_samples=5000, **kwargs + ): + super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) self.method = method - self.batch_size = batch_size - self.n_samples = n_samples - self.max_energy = 0 - self.min = 0 def __call__(self, states: Iterable) -> Tensor: energies = [] @@ -71,10 +71,3 @@ def __call__(self, states: Iterable) -> Tensor: energies -= self.max_energy return energies - - def setup(self, env=None): - states = env.statebatch2proxy(2 * np.pi * np.random.rand(self.n_samples, 3)) - energies = self(states) - - self.max_energy = max(energies) - self.min = min(energies) - self.max_energy From 0edb340a9309fe7756add57ab7720fee88818fdc Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 13:50:31 -0400 Subject: [PATCH 059/100] changed default beta --- config/experiments/conformer.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index a3dc89c45..e48d92737 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -16,6 +16,7 @@ env: n_comp: 5 vonmises_min_concentration: 4 reward_func: boltzmann + reward_beta: 64 # GFlowNet hyperparameters gflownet: From a13e6cf4f6a1608a19606a823f45c75a24527cad Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 14:20:40 -0400 Subject: [PATCH 060/100] method passed in xtb config --- config/proxy/conformers/xtb.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/config/proxy/conformers/xtb.yaml b/config/proxy/conformers/xtb.yaml index 6d35712a0..5f7d13d1f 100644 --- a/config/proxy/conformers/xtb.yaml +++ b/config/proxy/conformers/xtb.yaml @@ -1 +1,3 @@ _target_: gflownet.proxy.conformers.xtb.XTBMoleculeEnergy + +method: gfnff From 8dee21092d6667c129c8306cbe8a5a22220e37d1 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 16:48:46 -0400 Subject: [PATCH 061/100] added xtb to requirements --- setup_conformer_conda.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/setup_conformer_conda.sh b/setup_conformer_conda.sh index d4a40cb27..9c575bdc7 100644 --- a/setup_conformer_conda.sh +++ b/setup_conformer_conda.sh @@ -13,6 +13,7 @@ conda activate $1 conda install mamba -n base -c conda-forge +mamba install xtb -c conda-forge mamba install tblite -c conda-forge mamba install tblite-python -c conda-forge From f8842d07af75b9340f84c2376132633df8b3986f Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 20:42:39 -0400 Subject: [PATCH 062/100] decreased default beta --- config/experiments/conformer.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index e48d92737..4507e90c7 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -16,7 +16,7 @@ env: n_comp: 5 vonmises_min_concentration: 4 reward_func: boltzmann - reward_beta: 64 + reward_beta: 32 # GFlowNet hyperparameters gflownet: From 38e95344bf54756ac355ca8aecc15a3067daeb4c Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 23 Jun 2023 20:43:44 -0400 Subject: [PATCH 063/100] normalizing energies to (0, 1) range --- gflownet/proxy/conformers/base.py | 22 +++++++++++++++++----- gflownet/proxy/conformers/tblite.py | 4 ++-- gflownet/proxy/conformers/torchani.py | 4 +--- gflownet/proxy/conformers/xtb.py | 3 +-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/gflownet/proxy/conformers/base.py b/gflownet/proxy/conformers/base.py index fcee38400..f3f513b4f 100644 --- a/gflownet/proxy/conformers/base.py +++ b/gflownet/proxy/conformers/base.py @@ -1,7 +1,8 @@ -from abc import ABC -from typing import Optional +from abc import ABC, abstractmethod +from typing import List, Optional import numpy as np +from torch import Tensor from gflownet.proxy.base import Proxy @@ -28,11 +29,22 @@ def __init__( self.batch_size = batch_size self.n_samples = n_samples self.max_energy = 0 - self.min = 0 + self.min_energy = 0 + self.min = -1 + + @abstractmethod + def compute_energy(self, states: List) -> Tensor: + pass + + def __call__(self, states: List) -> Tensor: + energies = self.compute_energy(states) + energies = (energies - self.max_energy) / (self.max_energy - self.min_energy) + + return energies def setup(self, env=None): states = env.statebatch2proxy(2 * np.pi * np.random.rand(self.n_samples, 3)) - energies = self(states) + energies = self.compute_energy(states) self.max_energy = max(energies) - self.min = min(energies) - self.max_energy + self.min_energy = min(energies) diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index 93bae15ca..0ee24895e 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -12,6 +12,7 @@ @ray.remote def get_energy(numbers, positions): with pipes(): + # The positions are converted from Angstrom to Bohr. calc = Calculator("GFN2-xTB", numbers, positions * 1.8897259886) res = calc.singlepoint() energy = res.get("energy").item() @@ -29,7 +30,7 @@ class TBLiteMoleculeEnergy(MoleculeEnergyBase): def __init__(self, batch_size: int = 1024, n_samples: int = 5000, **kwargs): super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) - def __call__(self, states: List) -> Tensor: + def compute_energy(self, states: List) -> Tensor: energies = [] for batch in _chunks(states, self.batch_size): @@ -37,6 +38,5 @@ def __call__(self, states: List) -> Tensor: energies.extend(ray.get(tasks)) energies = torch.tensor(energies, dtype=self.float, device=self.device) - energies -= self.max_energy return energies diff --git a/gflownet/proxy/conformers/torchani.py b/gflownet/proxy/conformers/torchani.py index cef2a308a..c5b229e68 100644 --- a/gflownet/proxy/conformers/torchani.py +++ b/gflownet/proxy/conformers/torchani.py @@ -47,7 +47,7 @@ def __init__( ).to(self.device) @torch.no_grad() - def __call__(self, states: Iterable) -> Tensor: + def compute_energy(self, states: Iterable) -> Tensor: """ Args ---- @@ -91,8 +91,6 @@ def __call__(self, states: Iterable) -> Tensor: else: energies = self.model((elements, coordinates)).energies.float() - energies -= self.max_energy - return energies def __deepcopy__(self, memo): diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 74fbbcdf3..42b81d447 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -60,7 +60,7 @@ def __init__( self.method = method - def __call__(self, states: Iterable) -> Tensor: + def compute_energy(self, states: Iterable) -> Tensor: energies = [] for batch in _chunks(states, self.batch_size): @@ -68,6 +68,5 @@ def __call__(self, states: Iterable) -> Tensor: energies.extend(ray.get(tasks)) energies = torch.tensor(energies, dtype=self.float, device=self.device) - energies -= self.max_energy return energies From 58c4164ec922c783c8701ec8aac9d90d0ec4ad90 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Sat, 24 Jun 2023 11:19:48 -0400 Subject: [PATCH 064/100] method dictionary for XTB --- gflownet/proxy/conformers/xtb.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 42b81d447..406858a54 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -21,6 +21,9 @@ from gflownet.utils.molecule.xtb import run_gfn_xtb +METHODS = {"gfn2": "gfn 2", "gfnff": "gfnff"} + + def _write_xyz_file( elements: npt.NDArray, coordinates: npt.NDArray, file_path: str ) -> None: @@ -58,7 +61,11 @@ def __init__( ): super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) - self.method = method + if method not in METHODS.keys(): + raise ValueError( + f"Unrecognized method: {method}, expected one from {METHODS.keys()}." + ) + self.method = METHODS[method] def compute_energy(self, states: Iterable) -> Tensor: energies = [] From ae2e70d6afaea2bb02b72b1493f89317b5c4f276 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Mon, 26 Jun 2023 11:56:12 -0400 Subject: [PATCH 065/100] joblib instead of ray --- config/experiments/conformer.yaml | 3 --- config/main.yaml | 3 --- gflownet/proxy/conformers/tblite.py | 18 +++++++++++++----- gflownet/proxy/conformers/xtb.py | 14 ++++++++++---- main.py | 5 ----- setup_conformer_conda.sh | 3 +-- 6 files changed, 24 insertions(+), 22 deletions(-) diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index 4507e90c7..0a7923cff 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -6,9 +6,6 @@ defaults: - override /proxy: conformers/tblite - override /logger: wandb -# Number of parallel ray jobs (XTB greatly benefits from parallelization) -n_jobs: 12 - # Environment env: length_traj: 10 diff --git a/config/main.yaml b/config/main.yaml index e06b57762..f28bb9367 100644 --- a/config/main.yaml +++ b/config/main.yaml @@ -14,8 +14,6 @@ float_precision: 32 n_samples: 1000 # Random seeds seed: 0 -# Number of parallel ray jobs -n_jobs: 1 # Hydra config hydra: @@ -26,4 +24,3 @@ hydra: # See: https://hydra.cc/docs/upgrades/1.1_to_1.2/changes_to_job_working_dir/ # See: https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory/#disable-changing-current-working-dir-to-jobs-output-dir chdir: True - diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index 0ee24895e..b80cda21d 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -1,15 +1,17 @@ +# This needs to be imported first due to conda/pip package conflicts. +from tblite.interface import Calculator + +import os from typing import List -import ray import torch -from tblite.interface import Calculator +from joblib import delayed, Parallel from torch import Tensor from wurlitzer import pipes from gflownet.proxy.conformers.base import MoleculeEnergyBase -@ray.remote def get_energy(numbers, positions): with pipes(): # The positions are converted from Angstrom to Bohr. @@ -31,11 +33,17 @@ def __init__(self, batch_size: int = 1024, n_samples: int = 5000, **kwargs): super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) def compute_energy(self, states: List) -> Tensor: + # Get the number of available CPUs. + n_jobs = len(os.sched_getaffinity(0)) + energies = [] for batch in _chunks(states, self.batch_size): - tasks = [get_energy.remote(s[:, 0], s[:, 1:]) for s in batch] - energies.extend(ray.get(tasks)) + energies.extend( + Parallel(n_jobs=n_jobs)( + delayed(get_energy)(s[:, 0], s[:, 1:]) for s in batch + ) + ) energies = torch.tensor(energies, dtype=self.float, device=self.device) diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 406858a54..f77c17239 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -5,14 +5,15 @@ except: pass +import os from pathlib import Path from tempfile import TemporaryDirectory from typing import Iterable import numpy as np import numpy.typing as npt -import ray import torch +from joblib import delayed, Parallel from torch import Tensor from wurlitzer import pipes @@ -39,7 +40,6 @@ def _write_xyz_file( f.write(line) -@ray.remote def get_energy(numbers, positions, method="gfnff"): directory = TemporaryDirectory() file_name = "input.xyz" @@ -68,11 +68,17 @@ def __init__( self.method = METHODS[method] def compute_energy(self, states: Iterable) -> Tensor: + # Get the number of available CPUs. + n_jobs = len(os.sched_getaffinity(0)) + energies = [] for batch in _chunks(states, self.batch_size): - tasks = [get_energy.remote(s[:, 0], s[:, 1:], self.method) for s in batch] - energies.extend(ray.get(tasks)) + energies.extend( + Parallel(n_jobs=n_jobs)( + delayed(get_energy)(s[:, 0], s[:, 1:], self.method) for s in batch + ) + ) energies = torch.tensor(energies, dtype=self.float, device=self.device) diff --git a/main.py b/main.py index d4175868c..2fdb1b791 100644 --- a/main.py +++ b/main.py @@ -16,15 +16,10 @@ import hydra import pandas as pd -import ray -import yaml -from omegaconf import DictConfig, OmegaConf @hydra.main(config_path="./config", config_name="main", version_base="1.1") def main(config): - ray.init(num_cpus=config.n_jobs, log_to_driver=False) - # Get current directory and set it as root log dir for Logger cwd = os.getcwd() config.logger.logdir.root = cwd diff --git a/setup_conformer_conda.sh b/setup_conformer_conda.sh index 9c575bdc7..96efcd4e5 100644 --- a/setup_conformer_conda.sh +++ b/setup_conformer_conda.sh @@ -9,7 +9,7 @@ module --force purge module load cuda/11.7 conda create -n $1 python=3.8 -conda activate $1 +source activate $1 conda install mamba -n base -c conda-forge @@ -25,7 +25,6 @@ python -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spl # Install DGL (see https://www.dgl.ai/pages/start.html) python -m pip install dgl -f https://data.dgl.ai/wheels/cu117/repo.html # Requirements to run -python -m pip install ray ray[tune] ray[default] python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d rdkit wurlitzer # Conditional requirements python -m pip install wandb matplotlib plotly pymatgen gdown From e5962a529b99e9cd12ce61f278f587f61ae5e2ff Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 27 Jun 2023 10:50:19 -0400 Subject: [PATCH 066/100] normalization controlled by an argument --- gflownet/proxy/conformers/base.py | 22 ++++++++++++++++++---- gflownet/proxy/conformers/tblite.py | 12 ++++++++++-- gflownet/proxy/conformers/torchani.py | 9 ++++++++- gflownet/proxy/conformers/xtb.py | 11 +++++++++-- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/gflownet/proxy/conformers/base.py b/gflownet/proxy/conformers/base.py index f3f513b4f..5806c321e 100644 --- a/gflownet/proxy/conformers/base.py +++ b/gflownet/proxy/conformers/base.py @@ -12,6 +12,7 @@ def __init__( self, batch_size: Optional[int] = 128, n_samples: int = 5000, + normalize: bool = True, **kwargs, ): """ @@ -23,14 +24,19 @@ def __init__( n_samples : int Number of samples that will be used to estimate minimum and maximum energy. + + normalize : bool + Whether to truncate the energies to a (0, 1) range (estimated based on + sample conformers). """ super().__init__(**kwargs) self.batch_size = batch_size self.n_samples = n_samples - self.max_energy = 0 - self.min_energy = 0 - self.min = -1 + self.normalize = normalize + self.max_energy = None + self.min_energy = None + self.min = None @abstractmethod def compute_energy(self, states: List) -> Tensor: @@ -38,7 +44,10 @@ def compute_energy(self, states: List) -> Tensor: def __call__(self, states: List) -> Tensor: energies = self.compute_energy(states) - energies = (energies - self.max_energy) / (self.max_energy - self.min_energy) + energies = energies - self.max_energy + + if self.normalize: + energies = energies / (self.max_energy - self.min_energy) return energies @@ -48,3 +57,8 @@ def setup(self, env=None): self.max_energy = max(energies) self.min_energy = min(energies) + + if self.normalize: + self.min = -1 + else: + self.min = self.min_energy - self.max_energy diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index b80cda21d..a24789d75 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -29,8 +29,16 @@ def _chunks(lst, n): class TBLiteMoleculeEnergy(MoleculeEnergyBase): - def __init__(self, batch_size: int = 1024, n_samples: int = 5000, **kwargs): - super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) + def __init__( + self, + batch_size: int = 1024, + n_samples: int = 5000, + normalize: bool = True, + **kwargs + ): + super().__init__( + batch_size=batch_size, n_samples=n_samples, normalize=normalize, **kwargs + ) def compute_energy(self, states: List) -> Tensor: # Get the number of available CPUs. diff --git a/gflownet/proxy/conformers/torchani.py b/gflownet/proxy/conformers/torchani.py index c5b229e68..7e5f60785 100644 --- a/gflownet/proxy/conformers/torchani.py +++ b/gflownet/proxy/conformers/torchani.py @@ -20,6 +20,7 @@ def __init__( use_ensemble: bool = True, batch_size: Optional[int] = 128, n_samples: int = 5000, + normalize: bool = True, **kwargs, ): """ @@ -33,8 +34,14 @@ def __init__( batch_size : int Batch size for TorchANI. If none, will process all states as a single batch. + + normalize : bool + Whether to truncate the energies to a (0, 1) range (estimated based on + sample conformers). """ - super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) + super().__init__( + batch_size=batch_size, n_samples=n_samples, normalize=normalize, **kwargs + ) if TORCHANI_MODELS.get(model) is None: raise ValueError( diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index f77c17239..93019d2a9 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -57,9 +57,16 @@ def get_energy(numbers, positions, method="gfnff"): class XTBMoleculeEnergy(MoleculeEnergyBase): def __init__( - self, method: str = "gfnff", batch_size=1024, n_samples=5000, **kwargs + self, + method: str = "gfnff", + batch_size=1024, + n_samples=5000, + normalize: bool = True, + **kwargs, ): - super().__init__(batch_size=batch_size, n_samples=n_samples, **kwargs) + super().__init__( + batch_size=batch_size, n_samples=n_samples, normalize=normalize, **kwargs + ) if method not in METHODS.keys(): raise ValueError( From 9d5c3ab00e87baded602ce249d99e93ac5bd104f Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 27 Jun 2023 11:08:47 -0400 Subject: [PATCH 067/100] renamed utils/molecule/xtb to xtb_cli --- gflownet/proxy/conformers/xtb.py | 2 +- gflownet/utils/molecule/{xtb.py => xtb_cli.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename gflownet/utils/molecule/{xtb.py => xtb_cli.py} (100%) diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index 93019d2a9..fa12ec61a 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -19,7 +19,7 @@ from gflownet.proxy.conformers.base import MoleculeEnergyBase from gflownet.proxy.conformers.tblite import _chunks -from gflownet.utils.molecule.xtb import run_gfn_xtb +from gflownet.utils.molecule.xtb_cli import run_gfn_xtb METHODS = {"gfn2": "gfn 2", "gfnff": "gfnff"} diff --git a/gflownet/utils/molecule/xtb.py b/gflownet/utils/molecule/xtb_cli.py similarity index 100% rename from gflownet/utils/molecule/xtb.py rename to gflownet/utils/molecule/xtb_cli.py From e8596bc9fd56cce0fa3cd4c93c6e538348cd10dc Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 27 Jun 2023 11:13:53 -0400 Subject: [PATCH 068/100] function for finding rotatable bonds --- gflownet/utils/molecule/rotatable_bonds.py | 65 ++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 gflownet/utils/molecule/rotatable_bonds.py diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py new file mode 100644 index 000000000..03941f8c6 --- /dev/null +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -0,0 +1,65 @@ +# Taken from https://pyxtal.readthedocs.io/en/latest/_modules/pyxtal/molecule.html. + +from operator import itemgetter + + +def find_rotor_from_smile(smile): + """ + Find the positions of rotatable bonds in the molecule. + """ + + def cleaner(list_to_clean, neighbors): + """ + Remove duplicate torsion from a list of atom index tuples. + """ + + for_remove = [] + for x in reversed(range(len(list_to_clean))): + ix0 = itemgetter(0)(list_to_clean[x]) + ix3 = itemgetter(3)(list_to_clean[x]) + # for i-j-k-l, we don't want i, l are the ending members + # here C-C-S=O is not a good choice since O is only 1-coordinated + if neighbors[ix0] > 1 and neighbors[ix3] > 1: + for y in reversed(range(x)): + ix1 = itemgetter(1)(list_to_clean[x]) + ix2 = itemgetter(2)(list_to_clean[x]) + iy1 = itemgetter(1)(list_to_clean[y]) + iy2 = itemgetter(2)(list_to_clean[y]) + if [ix1, ix2] == [iy1, iy2] or [ix1, ix2] == [iy2, iy1]: + for_remove.append(y) + else: + for_remove.append(x) + clean_list = [] + for i, v in enumerate(list_to_clean): + if i not in set(for_remove): + clean_list.append(v) + return clean_list + + if smile in ["Cl-", "F-", "Br-", "I-", "Li+", "Na+"]: + return [] + else: + from rdkit import Chem + + smarts_torsion1 = "[*]~[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]~[*]" + smarts_torsion2 = "[*]~[^2]=[^2]~[*]" # C=C bonds + # smarts_torsion2="[*]~[^1]#[^1]~[*]" # C-C triples bonds, to be fixed + + mol = Chem.MolFromSmiles(smile) + N_atom = mol.GetNumAtoms() + neighbors = [len(a.GetNeighbors()) for a in mol.GetAtoms()] + # make sure that the ending members will be counted + neighbors[0] += 1 + neighbors[-1] += 1 + patn_tor1 = Chem.MolFromSmarts(smarts_torsion1) + torsion1 = cleaner(list(mol.GetSubstructMatches(patn_tor1)), neighbors) + patn_tor2 = Chem.MolFromSmarts(smarts_torsion2) + torsion2 = cleaner(list(mol.GetSubstructMatches(patn_tor2)), neighbors) + tmp = cleaner(torsion1 + torsion2, neighbors) + torsions = [] + for t in tmp: + (i, j, k, l) = t + b = mol.GetBondBetweenAtoms(j, k) + if not b.IsInRing(): + torsions.append(t) + # if len(torsions) > 6: torsions[1] = (4, 7, 10, 15) + return torsions From bc812363343d504040981810cc9456cbdd863480 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 27 Jun 2023 12:52:53 -0400 Subject: [PATCH 069/100] dynamically computing torsion angles (instead of using dataset) --- config/env/conformers/conformer.yaml | 6 ++-- gflownet/envs/conformers/conformer.py | 41 ++++++++++++++++++++------- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/config/env/conformers/conformer.yaml b/config/env/conformers/conformer.yaml index b0e7f21ee..53d89bb50 100644 --- a/config/env/conformers/conformer.yaml +++ b/config/env/conformers/conformer.yaml @@ -6,9 +6,9 @@ _target_: gflownet.envs.conformers.conformer.Conformer # smiles: 'CC(C(=O)NC)NC(=O)C' # alanine dipeptide # smiles: 'CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O' # ibuprofen smiles: 'O=C(c1ccc2n1CCC2C(=O)O)c3ccccc3' # ketorolac - -path_to_dataset: './data/conformers.npy' -url_to_dataset: 'https://drive.google.com/uc?id=1MefikIedDjUtUJtzCHwXhZtpXHJX3ImA' +# smiles: 'CCCCCC1=CC(=C(C(=C1)O)C2C=C(CCC2C(=C)C)C)O' # cannabidiol +# smiles: 'CN1C2CCC1C(C(C2)OC(=O)C3=CC=CC=C3)C(=O)OC' # cocaine +n_torsion_angles: 2 id: conformer policy_encoding_dim_per_angle: null diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 0101cc734..58c8da184 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -1,34 +1,40 @@ import copy -from typing import List +from typing import List, Optional, Tuple import numpy as np import numpy.typing as npt +from rdkit import Chem +from rdkit.Chem import AllChem from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus -from gflownet.utils.molecule.datasets import AtomPositionsDataset from gflownet.utils.molecule.rdkit_conformer import RDKitConformer +from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smile class Conformer(ContinuousTorus): """ Extension of continuous torus to conformer generation. Based on AlanineDipeptide, - but accepts any molecule (defined by SMILES, freely rotatable torsion angles, and - path to dataset containing sample conformers). + but accepts any molecule (defined by SMILES and freely rotatable torsion angles). """ def __init__( self, smiles: str, - path_to_dataset: str, - url_to_dataset: str, + n_torsion_angles: Optional[int] = 2, + torsion_indices: Optional[List[int]] = None, **kwargs, ): - atom_positions_dataset = AtomPositionsDataset( - smiles, path_to_dataset, url_to_dataset - ) - atom_positions = atom_positions_dataset.first() - torsion_angles = atom_positions_dataset.torsion_angles + if torsion_indices is None: + # We hard code default torsion indices for Alanine Dipeptide to preserve + # backward compatibility. + if smiles == "CC(C(=O)NC)NC(=O)C" and n_torsion_angles == 2: + torsion_indices = [2, 1] + else: + torsion_indices = list(range(n_torsion_angles)) + + atom_positions = Conformer._get_positions(smiles) + torsion_angles = Conformer._get_torsion_angles(smiles, torsion_indices) self.conformer = RDKitConformer(atom_positions, smiles, torsion_angles) # Conversions @@ -39,6 +45,19 @@ def __init__( self.sync_conformer_with_state() + @staticmethod + def _get_positions(smiles: str) -> npt.NDArray: + mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) + AllChem.EmbedMolecule(mol, randomSeed=0) + return mol.GetConformer().GetPositions() + + @staticmethod + def _get_torsion_angles(smiles: str, indices: List[int]) -> List[Tuple[int]]: + torsion_angles = find_rotor_from_smile(smiles) + torsion_angles = [torsion_angles[i] for i in indices] + return torsion_angles + def sync_conformer_with_state(self, state: List = None): if state is None: state = self.state From b3f4e641bc8c56907482277395b79bab07778e5b Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 26 Jul 2023 14:43:40 -0400 Subject: [PATCH 070/100] race condition fix and temporary workaround for white spaces in hydra arguments --- gflownet/utils/logger.py | 3 +++ main.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 1729a1e36..3d8e06238 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -44,6 +44,9 @@ def __init__( run_name = "{}".format( date_time, ) + else: + # TODO: remove + run_name = run_name.replace("_", " ") if self.do.online: import wandb diff --git a/main.py b/main.py index 2fdb1b791..152743f64 100644 --- a/main.py +++ b/main.py @@ -22,6 +22,11 @@ def main(config): # Get current directory and set it as root log dir for Logger cwd = os.getcwd() + # TODO: fix race condition in a more elegant way + import random + cwd += "/%08x" % random.getrandbits(32) + os.mkdir(cwd) + os.chdir(cwd) config.logger.logdir.root = cwd print(f"\nLogging directory of this run: {cwd}\n") From 583854e50623c622c570216239cc537c97bc2249 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Wed, 2 Aug 2023 01:32:02 -0400 Subject: [PATCH 071/100] implemented ns --- config/env/conformers/conformer.yaml | 1 + config/experiments/conformer.yaml | 1 + gflownet/envs/conformers/conformer.py | 7 ++++- gflownet/envs/htorus.py | 37 ++++++++++++++++++++++++++- setup_conformer.sh | 4 +++ 5 files changed, 48 insertions(+), 2 deletions(-) diff --git a/config/env/conformers/conformer.yaml b/config/env/conformers/conformer.yaml index 53d89bb50..6fb403c7d 100644 --- a/config/env/conformers/conformer.yaml +++ b/config/env/conformers/conformer.yaml @@ -9,6 +9,7 @@ smiles: 'O=C(c1ccc2n1CCC2C(=O)O)c3ccccc3' # ketorolac # smiles: 'CCCCCC1=CC(=C(C(=C1)O)C2C=C(CCC2C(=C)C)C)O' # cannabidiol # smiles: 'CN1C2CCC1C(C(C2)OC(=O)C3=CC=CC=C3)C(=O)OC' # cocaine n_torsion_angles: 2 +reward_sampling_method: rejection id: conformer policy_encoding_dim_per_angle: null diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index 0a7923cff..03db3e021 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -14,6 +14,7 @@ env: vonmises_min_concentration: 4 reward_func: boltzmann reward_beta: 32 + reward_sampling_method: rejection # GFlowNet hyperparameters gflownet: diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 58c8da184..f2f6d35e0 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -1,4 +1,5 @@ import copy +import torch from typing import List, Optional, Tuple import numpy as np @@ -94,7 +95,11 @@ def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]: def statetorch2kde( self, states: TensorType["batch_size", "state_dim"] ) -> TensorType["batch_size", "state_proxy_dim"]: - return states.cpu().numpy()[:, :-1] + if torch.is_tensor(states): + # why is this [:, :-1] needed? + return states.cpu().numpy()[:, :-1] + else: + return states def __deepcopy__(self, memo): cls = self.__class__ diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index b31de4f1a..ae337664e 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -53,6 +53,7 @@ def __init__( "vonmises_mean": 0.0, "vonmises_concentration": 0.001, }, + reward_sampling_method = 'rejection', **kwargs, ): assert n_dim > 0 @@ -77,6 +78,7 @@ def __init__( # TODO: assess if really needed self.state2oracle = self.state2proxy self.statebatch2oracle = self.statebatch2proxy + self.reward_sampling_method = reward_sampling_method # Base class init super().__init__( fixed_distribution=fixed_distribution, @@ -531,7 +533,15 @@ def get_grid_terminating_states(self, n_states: int) -> List[List]: return states # TODO: make generic for all environments - def sample_from_reward( + def sample_from_reward(self, n_samples: int, epsilon=1e-4, method='rejection_sampling' + ) -> TensorType["n_samples", "state_dim"]: + if self.reward_sampling_method == 'rejection': + return self.sample_from_reward_rejection(n_samples, epsilon) + elif self.reward_sampling_method == 'nested': + print("Warning: nested sampling ignores parameter n_samples and samples as many points as it wants (no idea why exactly, TBD)") + return self.sample_from_reward_nested() + + def sample_from_reward_rejection( self, n_samples: int, epsilon=1e-4 ) -> TensorType["n_samples", "state_dim"]: """ @@ -576,6 +586,31 @@ def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): aug_samples = np.concatenate(aug_samples) kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(aug_samples) return kde + + def sample_from_reward_nested(self): + import ultranest + def reward_func(angles): + angles = torch.tensor(angles) + rewards = self.reward_torchbatch(angles) + return np.log(rewards.cpu().detach().numpy()) + + def prior_transform(cube): + params = cube.copy() + + # transform location parameter: uniform prior + low = 0 + high = 2*np.pi + for idx, elem in enumerate(cube): + params[idx] = elem * (high - low) + low + return params + + param_names = [f'theta_{i}' for i in range(self.n_dim)] + sampler = ultranest.ReactiveNestedSampler(param_names, reward_func, prior_transform, + vectorized=True, ndraw_min=1000) + result = sampler.run() + + samples = result['samples'] + return samples def plot_reward_samples( self, diff --git a/setup_conformer.sh b/setup_conformer.sh index 34ed6e3fc..33d2593aa 100644 --- a/setup_conformer.sh +++ b/setup_conformer.sh @@ -20,5 +20,9 @@ python -m pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html python -m pip install numpy pandas hydra-core tqdm torchtyping six xtb scikit-learn torchani pytorch3d # Conditional requirements python -m pip install wandb matplotlib plotly pymatgen gdown +# for nested sampling +python -m pip install ultranest +# debugging +python -m pip install ipdb # Dev packages # python -m pip install black flake8 isort pylint ipdb jupyter pytest pytest-repeat From 6988ea605de9e32b4acfaeb240ded732c5af90af Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Wed, 2 Aug 2023 02:34:05 -0400 Subject: [PATCH 072/100] fix n_samples for ns --- gflownet/envs/htorus.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index ae337664e..af9aee4f6 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -538,8 +538,8 @@ def sample_from_reward(self, n_samples: int, epsilon=1e-4, method='rejection_sam if self.reward_sampling_method == 'rejection': return self.sample_from_reward_rejection(n_samples, epsilon) elif self.reward_sampling_method == 'nested': - print("Warning: nested sampling ignores parameter n_samples and samples as many points as it wants (no idea why exactly, TBD)") - return self.sample_from_reward_nested() + # print("Warning: nested sampling ignores parameter n_samples and samples as many points as it wants (no idea why exactly, TBD)") + return self.sample_from_reward_nested(n_samples) def sample_from_reward_rejection( self, n_samples: int, epsilon=1e-4 @@ -587,7 +587,7 @@ def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(aug_samples) return kde - def sample_from_reward_nested(self): + def sample_from_reward_nested(self, n_samples): import ultranest def reward_func(angles): angles = torch.tensor(angles) @@ -604,13 +604,20 @@ def prior_transform(cube): params[idx] = elem * (high - low) + low return params - param_names = [f'theta_{i}' for i in range(self.n_dim)] - sampler = ultranest.ReactiveNestedSampler(param_names, reward_func, prior_transform, + samples = [] + n_sampled = 0 + while n_sampled < n_samples: + param_names = [f'theta_{i}' for i in range(self.n_dim)] + sampler = ultranest.ReactiveNestedSampler(param_names, reward_func, prior_transform, vectorized=True, ndraw_min=1000) - result = sampler.run() - - samples = result['samples'] - return samples + result = sampler.run() + + samples.append(result['samples']) + n_sampled += result['samples'].shape[0] + print(f"Total samples: {n_sampled}") + samples = np.concatenate(samples, axis=0) + np.random.shuffle(samples) + return samples[:n_samples] def plot_reward_samples( self, From 90e71554f166ef509817fa5c03d687fc010f9ab3 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 8 Aug 2023 17:47:56 -0400 Subject: [PATCH 073/100] suppressing ultranest output; better state conversion; more verbose prints --- config/experiments/conformer.yaml | 2 +- gflownet/envs/conformers/conformer.py | 7 +--- gflownet/envs/htorus.py | 52 ++++++++++++++++----------- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index 03db3e021..4ae2a0b77 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -14,7 +14,7 @@ env: vonmises_min_concentration: 4 reward_func: boltzmann reward_beta: 32 - reward_sampling_method: rejection + reward_sampling_method: nested # GFlowNet hyperparameters gflownet: diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index f2f6d35e0..58c8da184 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -1,5 +1,4 @@ import copy -import torch from typing import List, Optional, Tuple import numpy as np @@ -95,11 +94,7 @@ def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]: def statetorch2kde( self, states: TensorType["batch_size", "state_dim"] ) -> TensorType["batch_size", "state_proxy_dim"]: - if torch.is_tensor(states): - # why is this [:, :-1] needed? - return states.cpu().numpy()[:, :-1] - else: - return states + return states.cpu().numpy()[:, :-1] def __deepcopy__(self, memo): cls = self.__class__ diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index af9aee4f6..fd5f0e05f 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -9,7 +9,6 @@ import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt -import pandas as pd import torch from sklearn.neighbors import KernelDensity from torch.distributions import Bernoulli, Categorical, Uniform, VonMises @@ -53,7 +52,7 @@ def __init__( "vonmises_mean": 0.0, "vonmises_concentration": 0.001, }, - reward_sampling_method = 'rejection', + reward_sampling_method="rejection", **kwargs, ): assert n_dim > 0 @@ -533,14 +532,15 @@ def get_grid_terminating_states(self, n_states: int) -> List[List]: return states # TODO: make generic for all environments - def sample_from_reward(self, n_samples: int, epsilon=1e-4, method='rejection_sampling' + def sample_from_reward( + self, n_samples: int, epsilon=1e-4, method="rejection_sampling" ) -> TensorType["n_samples", "state_dim"]: - if self.reward_sampling_method == 'rejection': + if self.reward_sampling_method == "rejection": return self.sample_from_reward_rejection(n_samples, epsilon) - elif self.reward_sampling_method == 'nested': + elif self.reward_sampling_method == "nested": # print("Warning: nested sampling ignores parameter n_samples and samples as many points as it wants (no idea why exactly, TBD)") return self.sample_from_reward_nested(n_samples) - + def sample_from_reward_rejection( self, n_samples: int, epsilon=1e-4 ) -> TensorType["n_samples", "state_dim"]: @@ -586,38 +586,50 @@ def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): aug_samples = np.concatenate(aug_samples) kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(aug_samples) return kde - + def sample_from_reward_nested(self, n_samples): import ultranest + from wurlitzer import pipes + def reward_func(angles): angles = torch.tensor(angles) rewards = self.reward_torchbatch(angles) return np.log(rewards.cpu().detach().numpy()) - + def prior_transform(cube): params = cube.copy() - # transform location parameter: uniform prior low = 0 - high = 2*np.pi + high = 2 * np.pi for idx, elem in enumerate(cube): params[idx] = elem * (high - low) + low return params - + samples = [] n_sampled = 0 + iteration = 0 + print(f"Running nested sampling (until {n_samples} samples are obtained)...") while n_sampled < n_samples: - param_names = [f'theta_{i}' for i in range(self.n_dim)] - sampler = ultranest.ReactiveNestedSampler(param_names, reward_func, prior_transform, - vectorized=True, ndraw_min=1000) - result = sampler.run() - - samples.append(result['samples']) - n_sampled += result['samples'].shape[0] - print(f"Total samples: {n_sampled}") + param_names = [f"theta_{i}" for i in range(self.n_dim)] + + with pipes(): + sampler = ultranest.ReactiveNestedSampler( + param_names, + reward_func, + prior_transform, + vectorized=True, + ndraw_min=1000, + ) + result = sampler.run() + + samples.append(result["samples"]) + n_sampled += result["samples"].shape[0] + print(f"Total samples (iteration #{iteration}): {n_sampled}.") + iteration += 1 samples = np.concatenate(samples, axis=0) + samples = np.concatenate([samples, np.ones((samples.shape[0], 1))], axis=1) np.random.shuffle(samples) - return samples[:n_samples] + return torch.Tensor(samples[:n_samples]) def plot_reward_samples( self, From 38f5d624474d71d30c40a2d318a47ef16308ec31 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 8 Aug 2023 17:48:37 -0400 Subject: [PATCH 074/100] more consistent configs --- config/env/conformers/conformer.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/env/conformers/conformer.yaml b/config/env/conformers/conformer.yaml index 6fb403c7d..6285f4120 100644 --- a/config/env/conformers/conformer.yaml +++ b/config/env/conformers/conformer.yaml @@ -9,7 +9,7 @@ smiles: 'O=C(c1ccc2n1CCC2C(=O)O)c3ccccc3' # ketorolac # smiles: 'CCCCCC1=CC(=C(C(=C1)O)C2C=C(CCC2C(=C)C)C)O' # cannabidiol # smiles: 'CN1C2CCC1C(C(C2)OC(=O)C3=CC=CC=C3)C(=O)OC' # cocaine n_torsion_angles: 2 -reward_sampling_method: rejection +reward_sampling_method: nested id: conformer policy_encoding_dim_per_angle: null From a4dd97398c5bb7b93fdc75bc102985593f8b097a Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 8 Aug 2023 17:54:36 -0400 Subject: [PATCH 075/100] commented out print changed to TODO --- gflownet/envs/htorus.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index fd5f0e05f..804f8cea4 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -538,7 +538,6 @@ def sample_from_reward( if self.reward_sampling_method == "rejection": return self.sample_from_reward_rejection(n_samples, epsilon) elif self.reward_sampling_method == "nested": - # print("Warning: nested sampling ignores parameter n_samples and samples as many points as it wants (no idea why exactly, TBD)") return self.sample_from_reward_nested(n_samples) def sample_from_reward_rejection( @@ -588,6 +587,8 @@ def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): return kde def sample_from_reward_nested(self, n_samples): + # TODO: nested sampling ignores parameter n_samples and samples + # as many points as it wants (no idea why exactly, TBD) import ultranest from wurlitzer import pipes From afa1488cef8bbe04a90019767e3fdc429de44d90 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 9 Aug 2023 12:17:51 -0400 Subject: [PATCH 076/100] removed unused import --- main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/main.py b/main.py index 152743f64..16baaa81c 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,6 @@ def main(config): # Get current directory and set it as root log dir for Logger cwd = os.getcwd() # TODO: fix race condition in a more elegant way - import random cwd += "/%08x" % random.getrandbits(32) os.mkdir(cwd) os.chdir(cwd) From 452a45a47973a376ba13140105a2fe027fb2cfd0 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 9 Aug 2023 12:28:46 -0400 Subject: [PATCH 077/100] outlier removal and clamping of energies --- gflownet/proxy/conformers/base.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/gflownet/proxy/conformers/base.py b/gflownet/proxy/conformers/base.py index 5806c321e..fa39f2e18 100644 --- a/gflownet/proxy/conformers/base.py +++ b/gflownet/proxy/conformers/base.py @@ -1,3 +1,4 @@ +import warnings from abc import ABC, abstractmethod from typing import List, Optional @@ -13,6 +14,8 @@ def __init__( batch_size: Optional[int] = 128, n_samples: int = 5000, normalize: bool = True, + remove_outliers: bool = True, + clamp: bool = True, **kwargs, ): """ @@ -28,12 +31,26 @@ def __init__( normalize : bool Whether to truncate the energies to a (0, 1) range (estimated based on sample conformers). + + remove_outliers : bool + Whether to adjust the min and max energy values estimated on the sample of + conformers to a +- 3 std range. + + clamp : bool + Whether to clamp the energies to the estimated min and max values. """ super().__init__(**kwargs) + if remove_outliers and not clamp: + warnings.warn( + "If outliers are removed it's recommended to also clamp the values." + ) + self.batch_size = batch_size self.n_samples = n_samples self.normalize = normalize + self.remove_outliers = remove_outliers + self.clamp = clamp self.max_energy = None self.min_energy = None self.min = None @@ -44,6 +61,10 @@ def compute_energy(self, states: List) -> Tensor: def __call__(self, states: List) -> Tensor: energies = self.compute_energy(states) + + if self.clamp: + energies = energies.clamp(self.min_energy, self.max_energy) + energies = energies - self.max_energy if self.normalize: @@ -58,6 +79,10 @@ def setup(self, env=None): self.max_energy = max(energies) self.min_energy = min(energies) + if self.remove_outliers: + self.max_energy = min(self.max_energy, energies.mean() + 3 * energies.std()) + self.min_energy = max(self.min_energy, energies.mean() - 3 * energies.std()) + if self.normalize: self.min = -1 else: From 49450aa5df353716c3bdd9f2c19196c84875a67b Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 9 Aug 2023 12:43:04 -0400 Subject: [PATCH 078/100] more robust conformer sampling (higher number of dimensions) --- gflownet/proxy/conformers/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gflownet/proxy/conformers/base.py b/gflownet/proxy/conformers/base.py index fa39f2e18..447a561b8 100644 --- a/gflownet/proxy/conformers/base.py +++ b/gflownet/proxy/conformers/base.py @@ -73,8 +73,12 @@ def __call__(self, states: List) -> Tensor: return energies def setup(self, env=None): - states = env.statebatch2proxy(2 * np.pi * np.random.rand(self.n_samples, 3)) - energies = self.compute_energy(states) + env_states = 2 * np.pi * np.random.rand(self.n_samples, env.n_dim) + env_states = np.concatenate( + [env_states, np.ones((env_states.shape[0], 1))], axis=1 + ) + proxy_states = env.statebatch2proxy(env_states) + energies = self.compute_energy(proxy_states) self.max_energy = max(energies) self.min_energy = min(energies) From 6045a200851ced86c6efbcd6daca3bfa6e795ca9 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 9 Aug 2023 13:41:21 -0400 Subject: [PATCH 079/100] quantiles used for outlier detection; increased number of samples --- gflownet/proxy/conformers/base.py | 8 ++++---- gflownet/proxy/conformers/tblite.py | 2 +- gflownet/proxy/conformers/torchani.py | 2 +- gflownet/proxy/conformers/xtb.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/gflownet/proxy/conformers/base.py b/gflownet/proxy/conformers/base.py index 447a561b8..6b68c45b0 100644 --- a/gflownet/proxy/conformers/base.py +++ b/gflownet/proxy/conformers/base.py @@ -12,7 +12,7 @@ class MoleculeEnergyBase(Proxy, ABC): def __init__( self, batch_size: Optional[int] = 128, - n_samples: int = 5000, + n_samples: int = 10000, normalize: bool = True, remove_outliers: bool = True, clamp: bool = True, @@ -34,7 +34,7 @@ def __init__( remove_outliers : bool Whether to adjust the min and max energy values estimated on the sample of - conformers to a +- 3 std range. + conformers by removing 0.01 quantiles. clamp : bool Whether to clamp the energies to the estimated min and max values. @@ -84,8 +84,8 @@ def setup(self, env=None): self.min_energy = min(energies) if self.remove_outliers: - self.max_energy = min(self.max_energy, energies.mean() + 3 * energies.std()) - self.min_energy = max(self.min_energy, energies.mean() - 3 * energies.std()) + self.max_energy = np.quantile(energies, 0.99) + self.min_energy = np.quantile(energies, 0.01) if self.normalize: self.min = -1 diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index a24789d75..2948b4b62 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -32,7 +32,7 @@ class TBLiteMoleculeEnergy(MoleculeEnergyBase): def __init__( self, batch_size: int = 1024, - n_samples: int = 5000, + n_samples: int = 10000, normalize: bool = True, **kwargs ): diff --git a/gflownet/proxy/conformers/torchani.py b/gflownet/proxy/conformers/torchani.py index 7e5f60785..bc2e689f2 100644 --- a/gflownet/proxy/conformers/torchani.py +++ b/gflownet/proxy/conformers/torchani.py @@ -19,7 +19,7 @@ def __init__( model: str = "ANI2x", use_ensemble: bool = True, batch_size: Optional[int] = 128, - n_samples: int = 5000, + n_samples: int = 10000, normalize: bool = True, **kwargs, ): diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index fa12ec61a..dc74c2acf 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -60,7 +60,7 @@ def __init__( self, method: str = "gfnff", batch_size=1024, - n_samples=5000, + n_samples=10000, normalize: bool = True, **kwargs, ): From 103eaca3eaee74fe14fbf5fcf7103f9f318d0382 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 9 Aug 2023 14:03:02 -0400 Subject: [PATCH 080/100] updated comment --- gflownet/gflownet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 27be310fa..9ea25fd45 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -758,7 +758,7 @@ def test(self, **plot_kwargs): log_density_true = dict_tt["log_density_true"] kde_true = dict_tt["kde_true"] else: - # Sample from reward via rejection sampling + # Sample from reward via rejection or nested sampling x_from_reward = self.env.sample_from_reward( n_samples=self.logger.test.n ) From 675787086c847a10c3f17b74ada00603c78f61c5 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 9 Aug 2023 14:50:18 -0400 Subject: [PATCH 081/100] scaling to higher dimensionality --- gflownet/envs/htorus.py | 7 ++----- gflownet/gflownet.py | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 804f8cea4..866406a1e 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -577,11 +577,8 @@ def sample_from_reward_rejection( def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): aug_samples = [] - for add_0 in [0, -2 * np.pi, 2 * np.pi]: - for add_1 in [0, -2 * np.pi, 2 * np.pi]: - aug_samples.append( - np.stack([samples[:, 0] + add_0, samples[:, 1] + add_1], axis=1) - ) + for offset in itertools.product([0, -2 * np.pi, 2 * np.pi], repeat=self.n_dim): + aug_samples.append(samples + offset) aug_samples = np.concatenate(aug_samples) kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(aug_samples) return kde diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 9ea25fd45..12e7771fb 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -797,11 +797,11 @@ def test(self, **plot_kwargs): # Plots - if hasattr(self.env, "plot_reward_samples"): + if hasattr(self.env, "plot_reward_samples") and self.env.n_dim <= 2: fig_reward_samples = self.env.plot_reward_samples(x_sampled, **plot_kwargs) else: fig_reward_samples = None - if hasattr(self.env, "plot_kde"): + if hasattr(self.env, "plot_kde") and self.env.n_dim <= 2: fig_kde_pred = self.env.plot_kde(kde_pred, **plot_kwargs) fig_kde_true = self.env.plot_kde(kde_true, **plot_kwargs) else: From 31ea9b051caa0c8b22ee79fa9c253b91b22e1c9b Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 10 Aug 2023 09:56:04 -0400 Subject: [PATCH 082/100] updated config --- config/experiments/conformer.yaml | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index 4ae2a0b77..e7ab2a965 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -20,24 +20,29 @@ env: gflownet: random_action_prob: 0.1 optimizer: - batch_size: 100 + batch_size: + forward: 100 + backward_dataset: 0 + backward_replay: 0 lr: 0.00001 z_dim: 16 lr_z_mult: 1000 n_train_steps: 40000 lr_decay_period: 1000000 - policy: - forward: - type: mlp - n_hid: 512 - n_layers: 5 - checkpoint: forward - backward: - type: mlp - n_hid: 512 - n_layers: 5 - shared_weights: False - checkpoint: backward + +# Policy +policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + backward: + type: mlp + n_hid: 512 + n_layers: 5 + shared_weights: False + checkpoint: backward # WandB logger: From 2e44acc676e76ca59e6eefac7b339ba7d3560faa Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Fri, 11 Aug 2023 16:34:01 -0400 Subject: [PATCH 083/100] EGNN-based policy --- config/experiments/conformer.yaml | 16 +--- config/policy/conformers/egnn.yaml | 16 ++++ config/policy/conformers/mlp.yaml | 13 +++ gflownet/envs/conformers/conformer.py | 29 ++++++ gflownet/policy/conformers/egnn.py | 133 ++++++++++++++++++++++++++ 5 files changed, 193 insertions(+), 14 deletions(-) create mode 100644 config/policy/conformers/egnn.yaml create mode 100644 config/policy/conformers/mlp.yaml create mode 100644 gflownet/policy/conformers/egnn.py diff --git a/config/experiments/conformer.yaml b/config/experiments/conformer.yaml index e7ab2a965..1efac2565 100644 --- a/config/experiments/conformer.yaml +++ b/config/experiments/conformer.yaml @@ -3,6 +3,7 @@ defaults: - override /env: conformers/conformer - override /gflownet: trajectorybalance + - override /policy: conformers/mlp - override /proxy: conformers/tblite - override /logger: wandb @@ -10,6 +11,7 @@ defaults: env: length_traj: 10 policy_encoding_dim_per_angle: 10 + policy_type: mlp n_comp: 5 vonmises_min_concentration: 4 reward_func: boltzmann @@ -30,20 +32,6 @@ gflownet: n_train_steps: 40000 lr_decay_period: 1000000 -# Policy -policy: - forward: - type: mlp - n_hid: 512 - n_layers: 5 - checkpoint: forward - backward: - type: mlp - n_hid: 512 - n_layers: 5 - shared_weights: False - checkpoint: backward - # WandB logger: lightweight: True diff --git a/config/policy/conformers/egnn.yaml b/config/policy/conformers/egnn.yaml new file mode 100644 index 000000000..829029b66 --- /dev/null +++ b/config/policy/conformers/egnn.yaml @@ -0,0 +1,16 @@ +_target_: gflownet.policy.conformers.egnn.EGNNPolicy + +forward: + n_gnn_layers: 7 + n_node_mlp_layers: 2 + n_pool_mlp_layers: 2 + egnn_hidden_dim: 128 + node_mlp_hidden_dim: 128 + pool_mlp_hidden_dim: 128 +backward: + n_gnn_layers: 7 + n_node_mlp_layers: 2 + n_pool_mlp_layers: 2 + egnn_hidden_dim: 128 + node_mlp_hidden_dim: 128 + pool_mlp_hidden_dim: 128 diff --git a/config/policy/conformers/mlp.yaml b/config/policy/conformers/mlp.yaml new file mode 100644 index 000000000..8979b892a --- /dev/null +++ b/config/policy/conformers/mlp.yaml @@ -0,0 +1,13 @@ +_target_: gflownet.policy.base.Policy + +forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward +backward: + type: mlp + n_hid: 512 + n_layers: 5 + shared_weights: False + checkpoint: backward diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 58c8da184..821222ffd 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -8,6 +8,8 @@ from torchtyping import TensorType from gflownet.envs.ctorus import ContinuousTorus +from gflownet.utils.molecule.constants import ad_atom_types +from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import RDKitConformer from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smile @@ -23,6 +25,7 @@ def __init__( smiles: str, n_torsion_angles: Optional[int] = 2, torsion_indices: Optional[List[int]] = None, + policy_type: str = "mlp", **kwargs, ): if torsion_indices is None: @@ -40,11 +43,19 @@ def __init__( # Conversions self.statebatch2oracle = self.statebatch2proxy self.statetorch2oracle = self.statetorch2proxy + if policy_type == "gnn": + self.statebatch2policy = self.statebatch2policy_gnn + elif policy_type != "mlp": + raise ValueError( + f"Unrecognized policy_type = {policy_type}, expected either 'mlp' or 'gnn'." + ) super().__init__(n_dim=len(self.conformer.freely_rotatable_tas), **kwargs) self.sync_conformer_with_state() + self.graph = MolDGLFeaturizer(ad_atom_types).mol2dgl(self.conformer.rdk_mol) + @staticmethod def _get_positions(smiles: str) -> npt.NDArray: mol = Chem.MolFromSmiles(smiles) @@ -88,6 +99,24 @@ def statebatch2proxy(self, states: List[List]) -> npt.NDArray: def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: return self.statebatch2proxy(states.cpu().numpy()) + def statebatch2policy_gnn(self, states: List[List]) -> npt.NDArray[np.float32]: + """ + Returns an array of GNN-format policy inputs with dimensionality + (n_states, n_atoms, 4), in which the first three columns encode atom positions, + and the last column encodes current timestep. + """ + policy_input = [] + for state in states: + conformer = self.sync_conformer_with_state(state) + positions = conformer.get_atom_positions() + policy_input.append( + np.concatenate( + [positions, np.full((positions.shape[0], 1), state[-1])], + axis=1, + ) + ) + return np.array(policy_input) + def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]: return np.array(states)[:, :-1] diff --git a/gflownet/policy/conformers/egnn.py b/gflownet/policy/conformers/egnn.py new file mode 100644 index 000000000..386527d00 --- /dev/null +++ b/gflownet/policy/conformers/egnn.py @@ -0,0 +1,133 @@ +from copy import deepcopy +from typing import Optional + +import dgl +import torch +from dgl.nn.pytorch.conv import EGNNConv +from dgl.nn.pytorch.glob import SumPooling +from torch import nn + +from gflownet.policy.base import Policy + + +class EGNNModel(nn.Module): + def __init__( + self, + out_dim: int, + node_feat_dim: int, + edge_feat_dim: int = 0, + n_gnn_layers: int = 7, + n_node_mlp_layers: int = 2, + n_pool_mlp_layers: int = 2, + egnn_hidden_dim: int = 128, + node_mlp_hidden_dim: int = 128, + pool_mlp_hidden_dim: int = 128, + ): + super().__init__() + + self.egnn_layers = [] + for i in range(n_gnn_layers): + self.egnn_layers.append( + EGNNConv( + node_feat_dim if i == 0 else egnn_hidden_dim, + egnn_hidden_dim, + egnn_hidden_dim, + edge_feat_dim, + ) + ) + + node_mlp_layers = [] + for i in range(n_node_mlp_layers): + node_mlp_layers.append( + ( + nn.Linear( + egnn_hidden_dim if i == 0 else node_mlp_hidden_dim, + node_mlp_hidden_dim, + ) + ) + ) + if i < n_node_mlp_layers - 1: + node_mlp_layers.append(nn.SiLU()) + self.node_mlp = nn.Sequential(*node_mlp_layers) + + self.pool = SumPooling() + + pool_mlp_layers = [] + for i in range(n_pool_mlp_layers): + pool_mlp_layers.append( + ( + nn.Linear( + node_mlp_hidden_dim if i == 0 else pool_mlp_hidden_dim, + pool_mlp_hidden_dim if i < n_pool_mlp_layers - 1 else out_dim, + ) + ) + ) + if i < n_pool_mlp_layers - 1: + pool_mlp_layers.append(nn.SiLU()) + self.pool_mlp = nn.Sequential(*pool_mlp_layers) + + def forward( + self, + g: dgl.DGLGraph, + node_feat: torch.Tensor, + coord_feat: torch.Tensor, + edge_feat: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + h, x = node_feat, coord_feat + for egnn_layer in self.egnn_layers: + h, x = egnn_layer(g, h, x, edge_feat) + h = self.node_mlp(h) + h = self.pool(g, h) + h = self.pool_mlp(h) + + return h + + +class EGNNPolicy(Policy): + def __init__(self, config, env, device, float_precision, base=None): + self.model = None + self.config = None + self.graph = env.graph + # We increase the node feature size by 1 to anticipate including current + # timestamp as one of the features. + self.node_feat_dim = env.graph.ndata["atom_features"].shape[1] + 1 + self.edge_feat_dim = env.graph.edata["edge_features"].shape[1] + self.is_model = True + + super().__init__( + config=config, + env=env, + device=device, + float_precision=float_precision, + base=base, + ) + + def parse_config(self, config): + self.config = {} if config is None else config + + def instantiate(self): + self.model = EGNNModel( + self.output_dim, self.node_feat_dim, self.edge_feat_dim, **self.config + ).to(self.device) + + def __call__(self, states: torch.Tensor) -> torch.Tensor: + graphs = [] + for state in states: + graph = deepcopy(self.graph) + graph.ndata["atom_features"] = torch.cat( + [ + graph.ndata["atom_features"], + torch.Tensor(state[:, -1]).unsqueeze(-1), + ], + dim=1, + ) + graph.ndata["coordinates"] = state[:, :-1] + graphs.append(graph) + batch = dgl.batch(graphs) + output = self.model( + batch, + batch.ndata["atom_features"], + batch.ndata["coordinates"], + batch.edata["edge_features"], + ) + return output From 7065121a466b5c5dd0cd8c3ff67083babe1b2730 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Sat, 12 Aug 2023 15:01:33 -0400 Subject: [PATCH 084/100] added missing __init__.py --- gflownet/policy/conformers/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 gflownet/policy/conformers/__init__.py diff --git a/gflownet/policy/conformers/__init__.py b/gflownet/policy/conformers/__init__.py new file mode 100644 index 000000000..e69de29bb From 54cbaf7d5d874614d77da823e0c66ecdc6612d9e Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Sat, 12 Aug 2023 16:01:01 -0400 Subject: [PATCH 085/100] better device support --- gflownet/policy/conformers/egnn.py | 5 +++-- gflownet/proxy/conformers/base.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/gflownet/policy/conformers/egnn.py b/gflownet/policy/conformers/egnn.py index 386527d00..a14c5afc8 100644 --- a/gflownet/policy/conformers/egnn.py +++ b/gflownet/policy/conformers/egnn.py @@ -25,9 +25,9 @@ def __init__( ): super().__init__() - self.egnn_layers = [] + egnn_layers = [] for i in range(n_gnn_layers): - self.egnn_layers.append( + egnn_layers.append( EGNNConv( node_feat_dim if i == 0 else egnn_hidden_dim, egnn_hidden_dim, @@ -35,6 +35,7 @@ def __init__( edge_feat_dim, ) ) + self.egnn_layers = nn.ModuleList(egnn_layers) node_mlp_layers = [] for i in range(n_node_mlp_layers): diff --git a/gflownet/proxy/conformers/base.py b/gflownet/proxy/conformers/base.py index 6b68c45b0..04fd8328e 100644 --- a/gflownet/proxy/conformers/base.py +++ b/gflownet/proxy/conformers/base.py @@ -78,7 +78,7 @@ def setup(self, env=None): [env_states, np.ones((env_states.shape[0], 1))], axis=1 ) proxy_states = env.statebatch2proxy(env_states) - energies = self.compute_energy(proxy_states) + energies = self.compute_energy(proxy_states).cpu().numpy() self.max_energy = max(energies) self.min_energy = min(energies) From 23bb90f0eabd8d54a983c42dd98d8455a9428a9d Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Sat, 12 Aug 2023 16:09:15 -0400 Subject: [PATCH 086/100] better device support --- gflownet/policy/conformers/egnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/policy/conformers/egnn.py b/gflownet/policy/conformers/egnn.py index a14c5afc8..70bf6dd4b 100644 --- a/gflownet/policy/conformers/egnn.py +++ b/gflownet/policy/conformers/egnn.py @@ -114,11 +114,11 @@ def instantiate(self): def __call__(self, states: torch.Tensor) -> torch.Tensor: graphs = [] for state in states: - graph = deepcopy(self.graph) + graph = deepcopy(self.graph).to(self.device) graph.ndata["atom_features"] = torch.cat( [ graph.ndata["atom_features"], - torch.Tensor(state[:, -1]).unsqueeze(-1), + state[:, -1].unsqueeze(-1), ], dim=1, ) From 6448a507bb2a6360eda75314dbf60499dd8f12ff Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Sat, 12 Aug 2023 16:24:18 -0400 Subject: [PATCH 087/100] better device support --- gflownet/envs/htorus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 74390aafd..b0dd6e7aa 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -591,7 +591,7 @@ def sample_from_reward_nested(self, n_samples): from wurlitzer import pipes def reward_func(angles): - angles = torch.tensor(angles) + angles = torch.tensor(angles).to(self.device) rewards = self.reward_torchbatch(angles) return np.log(rewards.cpu().detach().numpy()) From f4da186799da45bc9d4726a4d5d46f1b5a3c6539 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 7 Sep 2023 16:11:54 -0400 Subject: [PATCH 088/100] removed ConformersDataset --- gflownet/utils/molecule/datasets.py | 47 ----------------------------- 1 file changed, 47 deletions(-) diff --git a/gflownet/utils/molecule/datasets.py b/gflownet/utils/molecule/datasets.py index feb49a5cb..738c15371 100644 --- a/gflownet/utils/molecule/datasets.py +++ b/gflownet/utils/molecule/datasets.py @@ -1,9 +1,6 @@ -import dgl import numpy as np from gflownet.utils.common import download_file_if_not_exists -from gflownet.utils.molecule import constants -from gflownet.utils.molecule.dgl_conformer import DGLConformer class AtomPositionsDataset: @@ -26,47 +23,3 @@ def sample(self, size=None): def first(self): return self[0] - - -class ConformersDataset: - def __init__(self, path_to_data, url_to_data): - # TODO create a new dataset if path_to_data or url_to_data doesn't exist - path_to_data = download_file_if_not_exists(path_to_data, url_to_data) - with open(path_to_data, "rb") as inp: - self.conformers = pickle.load(inp) - - def get_conformer(self): - """ - Returns dgl graph with features stored in the dataset: - - ndata: - - atom features - - atomic numbers - - atom position - - edata: - - edge features - - rotatable bonds mask - """ - # TODO make it work if there're several conformers for a single molecule - smiles = np.random.choice(self.conformers.keys()) - edges = self.conformers[smiles]["edges"] - graph = dgl.graph(edges) - graph.ndata[constants.atom_feature_name] = self.conformers[smiles][ - constants.atom_feature_name - ] - graph.ndata[constants.atomic_numbers_name] = self.conformers[smiles][ - constants.atomic_numbers_name - ] - graph.edata[constants.edge_feature_name] = self.conformers[smiles][ - constants.edge_feature_name - ] - graph.edata[constants.rotatable_bonds_mask] = self.conformers[smiles][ - constants.rotatable_bonds_mask - ] - conf_idx = np.random.randint( - 0, self.conformers[smiles][constants.atom_position_name].shape[0] - ) - graph.ndata[constants.atom_position_name] = self.conformers[smiles][ - constants.atom_position_name - ][conf_idx] - conformer = DGLConformer(graph) - return smiles, conformer From 0897394a74fba9584c23b715ab3996b8678938fc Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 7 Sep 2023 16:15:27 -0400 Subject: [PATCH 089/100] remove old_conformer --- gflownet/utils/molecule/old_conformer.py | 130 ----------------------- 1 file changed, 130 deletions(-) delete mode 100644 gflownet/utils/molecule/old_conformer.py diff --git a/gflownet/utils/molecule/old_conformer.py b/gflownet/utils/molecule/old_conformer.py deleted file mode 100644 index b881bf4cf..000000000 --- a/gflownet/utils/molecule/old_conformer.py +++ /dev/null @@ -1,130 +0,0 @@ -from collections import defaultdict -from copy import deepcopy - -import numpy as np -import torch -from rdkit import Chem -from rdkit.Chem import AllChem, TorsionFingerprints, rdMolTransforms -from rdkit.Geometry.rdGeometry import Point3D - -from gflownet.utils.molecule import constants -from gflownet.utils.molecule.conformer_base import ConformerBase -from gflownet.utils.molecule.featurizer import MolDGLFeaturizer -from gflownet.utils.molecule.rdkit_conformer import RDKitConformer - - -class Conformer(RDKitConformer): - def __init__(self, atom_positions, smiles, atom_types, freely_rotatable_tas=None): - """ - :param atom_positions: numpy.ndarray of shape [num_atoms, 3] of dtype float64 - """ - super(Conformer, self).__init__(atom_positions, smiles, freely_rotatable_tas) - - self.featuraiser = MolDGLFeaturizer(atom_types) - # dgl graph is not supposed to be consistent with rdk_conf untill it is returned via .dgl_graph - self._dgl_graph = self.featuraiser.mol2dgl(self.rdk_mol) - self.set_atom_positions_dgl(atom_positions) - self.ta_to_index = defaultdict(lambda: None) - - @property - def dgl_graph(self): - pos = self.get_atom_positions() - self.set_atom_positions_dgl(pos) - return self._dgl_graph - - def set_atom_positions_dgl(self, atom_positions): - """Set atom positions of the self.dgl_graph to the input atom_positions values - :param atom_positions: 2d numpy array of shape [num atoms, 3] with new atom positions - """ - self._dgl_graph.ndata[constants.atom_position_name] = torch.Tensor( - atom_positions - ) - - def apply_actions(self, actions): - """ - Apply torsion angles updates defined by agent's actions - :param actions: a sequence of torsion angle updates of length = number of bonds in the molecule. - The order corresponds to the order of edges in self.dgl_graph, such that action[i] is - an update for the torsion angle corresponding to the edge[2i] - """ - for torsion_angle in self.freely_rotatable_tas: - idx = self.get_ta_index_in_dgl_graph(torsion_angle) - assert idx % 2 == 0 - # actions tensor is 2 times shorter that edges tensor (w/o reversed edges) - idx = int(idx // 2) - increment = actions[idx] - self.increment_torsion_angle(torsion_angle, increment) - - def get_ta_index_in_dgl_graph(self, torsion_angle): - """ - Get an index in the dgl graph of the first edge corresponding to the input torsion_angle - :param torsion_angle: tuple of 4 integers defining torsion angle - (these integers are indexes of the atoms in both self.rdk_mol and self.dgl_graph) - :returns: int, index of the torsion_angle's edge in self.dgl_graph - """ - if self.ta_to_index[torsion_angle] is None: - for idx, (s, d) in enumerate(zip(*self._dgl_graph.edges())): - if torsion_angle[1:3] == (s, d): - self.ta_to_index[torsion_angle] = idx - if self.ta_to_index[torsion_angle] is None: - raise Exception("Cannot find torsion angle {}".format(torsion_angle)) - return self.ta_to_index[torsion_angle] - - -if __name__ == "__main__": - from tabulate import tabulate - - from gflownet.utils.molecule.conformer_base import get_all_torsion_angles - - rmol = Chem.MolFromSmiles(constants.ad_smiles) - rmol = Chem.AddHs(rmol) - AllChem.EmbedMolecule(rmol) - rconf = rmol.GetConformer() - test_pos = rconf.GetPositions() - initial_tas = get_all_torsion_angles(rmol, rconf) - - conf = Conformer( - test_pos, constants.ad_smiles, constants.ad_atom_types, constants.ad_free_tas - ) - # check torsion angles randomisation - conf.randomize_freely_rotatable_tas() - conf_tas = conf.get_all_torsion_angles() - for k, v in conf_tas.items(): - if k in conf.freely_rotatable_tas: - assert not np.isclose(v, initial_tas[k]) - else: - assert np.isclose(v, initial_tas[k]) - - data = [[k, v1, v2] for (k, v1), v2 in zip(initial_tas.items(), conf_tas.values())] - print(tabulate(data, headers=["torsion angle", "initial value", "conf value"])) - - # check actions are applied - actions = ( - np.random.uniform(-1, 1, size=len(conf._dgl_graph.edges()[0]) // 2) * np.pi - ) - conf.apply_actions(actions) - new_conf_tas = conf.get_all_torsion_angles() - data = [[k, v1, v2] for (k, v1), v2 in zip(conf_tas.items(), new_conf_tas.values())] - print(tabulate(data, headers=["torsion angle", "before action", "after action"])) - actions_dict = { - ta: actions[conf.get_ta_index_in_dgl_graph(ta) // 2] - for ta in conf.freely_rotatable_tas - } - data = [[k, a, (conf_tas[k] + a), new_conf_tas[k]] for k, a in actions_dict.items()] - print( - tabulate( - data, headers=["torsion angle", "action", "init + action", "after action"] - ) - ) - - # check dgl_graph - conf.randomize_freely_rotatable_tas() - print("rdk pos", conf.get_atom_positions()[3]) - print( - "_dgl pos (should differ from rdk)", - conf._dgl_graph.ndata[constants.atom_position_name][3], - ) - print( - "dgl pos (should be the same as rdk)", - conf.dgl_graph.ndata[constants.atom_position_name][3], - ) From d3bf9f4b0750502b2cc27df0683b45327cfbc967 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Thu, 7 Sep 2023 16:23:02 -0400 Subject: [PATCH 090/100] moved install script --- setup_conformer_conda.sh => setup/conformer_conda.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename setup_conformer_conda.sh => setup/conformer_conda.sh (100%) diff --git a/setup_conformer_conda.sh b/setup/conformer_conda.sh similarity index 100% rename from setup_conformer_conda.sh rename to setup/conformer_conda.sh From e3c6aac8eaae952f26edf5bcdec36ad11c371f40 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 18 Sep 2023 18:55:40 -0400 Subject: [PATCH 091/100] fixed rotatatable bonds --- gflownet/proxy/conformers/tblite.py | 5 +- gflownet/proxy/conformers/xtb.py | 3 +- gflownet/utils/molecule/rotatable_bonds.py | 138 ++++++++++-------- gflownet/utils/molecule/torsions.py | 3 +- scripts/dav_mp20_stats.py | 1 + .../gflownet/utils/molecule/test_torsions.py | 6 +- 6 files changed, 88 insertions(+), 68 deletions(-) diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index 2948b4b62..867a653be 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -1,11 +1,10 @@ # This needs to be imported first due to conda/pip package conflicts. -from tblite.interface import Calculator - import os from typing import List import torch -from joblib import delayed, Parallel +from joblib import Parallel, delayed +from tblite.interface import Calculator from torch import Tensor from wurlitzer import pipes diff --git a/gflownet/proxy/conformers/xtb.py b/gflownet/proxy/conformers/xtb.py index dc74c2acf..ab45c23aa 100644 --- a/gflownet/proxy/conformers/xtb.py +++ b/gflownet/proxy/conformers/xtb.py @@ -13,7 +13,7 @@ import numpy as np import numpy.typing as npt import torch -from joblib import delayed, Parallel +from joblib import Parallel, delayed from torch import Tensor from wurlitzer import pipes @@ -21,7 +21,6 @@ from gflownet.proxy.conformers.tblite import _chunks from gflownet.utils.molecule.xtb_cli import run_gfn_xtb - METHODS = {"gfn2": "gfn 2", "gfnff": "gfnff"} diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index 03941f8c6..da92af8f5 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -1,65 +1,87 @@ -# Taken from https://pyxtal.readthedocs.io/en/latest/_modules/pyxtal/molecule.html. +# inspired by https://pyxtal.readthedocs.io/en/latest/_modules/pyxtal/molecule.html. from operator import itemgetter +import numpy as np +from rdkit import Chem -def find_rotor_from_smile(smile): + +def remove_duplicative_tas(tas_list): + """ + Remove duplicative torsion angles from a list of torsion angle tuples. + + Args + ---- + tas_list (list of tuples): A list of torsion angle tuples, each containing four values: + (atom1, atom2, atom3, atom4) + + Returns + ------- + list of tuples: A list of unique torsion angle tuples, where duplicative angles have been removed. + """ + tas = np.array(tas_list) + clean_tas = [] + considered = [] + for row in tas: + begin = row[1] + end = row[2] + if not (begin, end) in considered and begin < end: + duplicates = tas[np.logical_and(tas[:, 1] == begin, tas[:, 2] == end)] + duplicates = duplicates[ + np.where(duplicates[:, 0] == duplicates[:, 0].min())[0] + ] + clean_tas.append(duplicates[np.argmin(duplicates[:, 3])].tolist()) + considered.append((begin, end)) + return clean_tas + + +def get_rotatable_ta_list(mol): + """ + Find unique rotatable torsion angles of a molecule. Torsion angle is given by a tuple of adjacent atoms' + indecies (atom1, atom2, atom3, atom4), where + - atom2 < atom3 + - atom1 and atom4 are minimal among neighbours of atom2 and atom3 correspondingly + + Torsion angle is considered rotatable if: + - the bond (atom2, atom3) is a single bond + - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles) + - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed) + - atom2 and atom3 are not in the same ring + + Args + ---- + mol (RDKit Mol object): A molecule for which torsion angles need to be detected. + + Returns + ------- + list of tuples: A list of unique torsion angle tuples corresponding to rotatable bonds in the molecule. """ - Find the positions of rotatable bonds in the molecule. + torsion_pattern = "[*]~[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]~[*]" + substructures = Chem.MolFromSmarts(torsion_pattern) + torsion_angles = remove_duplicative_tas( + list(mol.GetSubstructMatches(substructures)) + ) + return torsion_angles + + +def find_rotor_from_smile(smile): """ + Find unique rotatable torsion angles of a molecule. Torsion angle is given by a tuple of adjacent atoms' + indecies (atom1, atom2, atom3, atom4), where + - atom2 < atom3 + - atom1 and atom4 are minimal among neighbours of atom2 and atom3 correspondingly + + Torsion angle is considered rotatable if: + - the bond (atom2, atom3) is a single bond + - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles) + - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed) + - atom2 and atom3 are not in the same ring - def cleaner(list_to_clean, neighbors): - """ - Remove duplicate torsion from a list of atom index tuples. - """ - - for_remove = [] - for x in reversed(range(len(list_to_clean))): - ix0 = itemgetter(0)(list_to_clean[x]) - ix3 = itemgetter(3)(list_to_clean[x]) - # for i-j-k-l, we don't want i, l are the ending members - # here C-C-S=O is not a good choice since O is only 1-coordinated - if neighbors[ix0] > 1 and neighbors[ix3] > 1: - for y in reversed(range(x)): - ix1 = itemgetter(1)(list_to_clean[x]) - ix2 = itemgetter(2)(list_to_clean[x]) - iy1 = itemgetter(1)(list_to_clean[y]) - iy2 = itemgetter(2)(list_to_clean[y]) - if [ix1, ix2] == [iy1, iy2] or [ix1, ix2] == [iy2, iy1]: - for_remove.append(y) - else: - for_remove.append(x) - clean_list = [] - for i, v in enumerate(list_to_clean): - if i not in set(for_remove): - clean_list.append(v) - return clean_list - - if smile in ["Cl-", "F-", "Br-", "I-", "Li+", "Na+"]: - return [] - else: - from rdkit import Chem - - smarts_torsion1 = "[*]~[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]~[*]" - smarts_torsion2 = "[*]~[^2]=[^2]~[*]" # C=C bonds - # smarts_torsion2="[*]~[^1]#[^1]~[*]" # C-C triples bonds, to be fixed - - mol = Chem.MolFromSmiles(smile) - N_atom = mol.GetNumAtoms() - neighbors = [len(a.GetNeighbors()) for a in mol.GetAtoms()] - # make sure that the ending members will be counted - neighbors[0] += 1 - neighbors[-1] += 1 - patn_tor1 = Chem.MolFromSmarts(smarts_torsion1) - torsion1 = cleaner(list(mol.GetSubstructMatches(patn_tor1)), neighbors) - patn_tor2 = Chem.MolFromSmarts(smarts_torsion2) - torsion2 = cleaner(list(mol.GetSubstructMatches(patn_tor2)), neighbors) - tmp = cleaner(torsion1 + torsion2, neighbors) - torsions = [] - for t in tmp: - (i, j, k, l) = t - b = mol.GetBondBetweenAtoms(j, k) - if not b.IsInRing(): - torsions.append(t) - # if len(torsions) > 6: torsions[1] = (4, 7, 10, 15) - return torsions + Parameters: + smile (str): The SMILES string representing a molecule. + + Returns: + list of tuples: A list of unique torsion angle tuples corresponding to rotatable bonds in the molecule. + """ + mol = Chem.MolFromSmiles(smile) + return get_rotatable_ta_list(mol) diff --git a/gflownet/utils/molecule/torsions.py b/gflownet/utils/molecule/torsions.py index a84146933..a32d271b6 100644 --- a/gflownet/utils/molecule/torsions.py +++ b/gflownet/utils/molecule/torsions.py @@ -1,7 +1,6 @@ -import torch import networkx as nx import numpy as np - +import torch from pytorch3d.transforms import axis_angle_to_matrix from gflownet.utils.molecule import constants diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 2b3e7ee5d..3df1c78c9 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -20,6 +20,7 @@ from collections import Counter from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders + from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index dba8f7ea5..acdeef1db 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -1,15 +1,14 @@ +import dgl import pytest import torch -import dgl - from rdkit import Chem from rdkit.Chem import AllChem from rdkit.Geometry.rdGeometry import Point3D -from gflownet.utils.molecule.torsions import get_rotation_masks, apply_rotations from gflownet.utils.molecule import constants from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values +from gflownet.utils.molecule.torsions import apply_rotations, get_rotation_masks def test_four_nodes_chain(): @@ -146,6 +145,7 @@ def stress_test_apply_rotation_alanine_dipeptide(): from rdkit import Chem from rdkit.Chem import AllChem from rdkit.Geometry.rdGeometry import Point3D + from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values From 3dda1f85fe84fa7e74e8717a116f6eaf1ec750e9 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 18 Sep 2023 18:57:44 -0400 Subject: [PATCH 092/100] add test --- tests/gflownet/utils/molecule/test_rotatable_bonds.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 tests/gflownet/utils/molecule/test_rotatable_bonds.py diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py b/tests/gflownet/utils/molecule/test_rotatable_bonds.py new file mode 100644 index 000000000..2395972d7 --- /dev/null +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -0,0 +1,11 @@ +from rdkit import Chem + +from gflownet.utils.molecule import constants +from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smile + + +def test_simple_ad(): + tas = find_rotor_from_smile(constants.ad_smiles) + assert len(tas) == 4 + expected = [[0, 1, 2, 3], [0, 1, 6, 7], [1, 2, 4, 5], [1, 6, 7, 8]] + assert tas == expected From e792e49cca14ba697a48faf2dc3ce58869ef9171 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Mon, 18 Sep 2023 19:07:47 -0400 Subject: [PATCH 093/100] isort fix --- gflownet/proxy/conformers/tblite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gflownet/proxy/conformers/tblite.py b/gflownet/proxy/conformers/tblite.py index 867a653be..73d8ba037 100644 --- a/gflownet/proxy/conformers/tblite.py +++ b/gflownet/proxy/conformers/tblite.py @@ -1,10 +1,11 @@ # This needs to be imported first due to conda/pip package conflicts. +from tblite.interface import Calculator # isort: skip + import os from typing import List import torch from joblib import Parallel, delayed -from tblite.interface import Calculator from torch import Tensor from wurlitzer import pipes From 34d70a099101c37b61be283b25ce2f3bba785c8f Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Mon, 18 Sep 2023 21:32:23 -0400 Subject: [PATCH 094/100] formatting changes & typo fixes --- gflownet/utils/molecule/rotatable_bonds.py | 55 +++++++++++----------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index da92af8f5..4f082cbef 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -1,23 +1,22 @@ -# inspired by https://pyxtal.readthedocs.io/en/latest/_modules/pyxtal/molecule.html. - -from operator import itemgetter +# Inspired by https://pyxtal.readthedocs.io/en/latest/_modules/pyxtal/molecule.html. import numpy as np from rdkit import Chem -def remove_duplicative_tas(tas_list): +def remove_duplicate_tas(tas_list): """ - Remove duplicative torsion angles from a list of torsion angle tuples. + Remove duplicate torsion angles from a list of torsion angle tuples. Args ---- - tas_list (list of tuples): A list of torsion angle tuples, each containing four values: - (atom1, atom2, atom3, atom4) + tas_list : list of tuples + A list of torsion angle tuples, each containing four values: + (atom1, atom2, atom3, atom4). Returns ------- - list of tuples: A list of unique torsion angle tuples, where duplicative angles have been removed. + list of tuples: A list of unique torsion angle tuples, where duplicate angles have been removed. """ tas = np.array(tas_list) clean_tas = [] @@ -38,19 +37,20 @@ def remove_duplicative_tas(tas_list): def get_rotatable_ta_list(mol): """ Find unique rotatable torsion angles of a molecule. Torsion angle is given by a tuple of adjacent atoms' - indecies (atom1, atom2, atom3, atom4), where - - atom2 < atom3 - - atom1 and atom4 are minimal among neighbours of atom2 and atom3 correspondingly + indices (atom1, atom2, atom3, atom4), where: + - atom2 < atom3, + - atom1 and atom4 are minimal among neighbours of atom2 and atom3 correspondingly. Torsion angle is considered rotatable if: - - the bond (atom2, atom3) is a single bond - - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles) - - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed) - - atom2 and atom3 are not in the same ring + - the bond (atom2, atom3) is a single bond, + - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles), + - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed), + - atom2 and atom3 are not in the same ring. Args ---- - mol (RDKit Mol object): A molecule for which torsion angles need to be detected. + mol : RDKit Mol object + A molecule for which torsion angles need to be detected. Returns ------- @@ -58,30 +58,31 @@ def get_rotatable_ta_list(mol): """ torsion_pattern = "[*]~[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]~[*]" substructures = Chem.MolFromSmarts(torsion_pattern) - torsion_angles = remove_duplicative_tas( + torsion_angles = remove_duplicate_tas( list(mol.GetSubstructMatches(substructures)) ) return torsion_angles -def find_rotor_from_smile(smile): +def find_rotor_from_smiles(smiles): """ Find unique rotatable torsion angles of a molecule. Torsion angle is given by a tuple of adjacent atoms' - indecies (atom1, atom2, atom3, atom4), where - - atom2 < atom3 - - atom1 and atom4 are minimal among neighbours of atom2 and atom3 correspondingly + indices (atom1, atom2, atom3, atom4), where: + - atom2 < atom3, + - atom1 and atom4 are minimal among neighbours of atom2 and atom3 correspondingly. Torsion angle is considered rotatable if: - - the bond (atom2, atom3) is a single bond - - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles) - - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed) - - atom2 and atom3 are not in the same ring + - the bond (atom2, atom3) is a single bond, + - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles), + - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed), + - atom2 and atom3 are not in the same ring. Parameters: - smile (str): The SMILES string representing a molecule. + smiles : str + The SMILES string representing a molecule. Returns: list of tuples: A list of unique torsion angle tuples corresponding to rotatable bonds in the molecule. """ - mol = Chem.MolFromSmiles(smile) + mol = Chem.MolFromSmiles(smiles) return get_rotatable_ta_list(mol) From dc7233ef2b252c3114a334e76b6511f3baaafc6b Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Mon, 18 Sep 2023 21:34:34 -0400 Subject: [PATCH 095/100] black --- gflownet/utils/molecule/rotatable_bonds.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index 4f082cbef..1f4e7f970 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -58,9 +58,7 @@ def get_rotatable_ta_list(mol): """ torsion_pattern = "[*]~[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]~[*]" substructures = Chem.MolFromSmarts(torsion_pattern) - torsion_angles = remove_duplicate_tas( - list(mol.GetSubstructMatches(substructures)) - ) + torsion_angles = remove_duplicate_tas(list(mol.GetSubstructMatches(substructures))) return torsion_angles From 4384c4e37e4dedf297c477a445f2723aa7926197 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Tue, 19 Sep 2023 12:39:55 -0400 Subject: [PATCH 096/100] fixed function name --- gflownet/envs/conformers/conformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/conformers/conformer.py b/gflownet/envs/conformers/conformer.py index 821222ffd..7a7d458ae 100644 --- a/gflownet/envs/conformers/conformer.py +++ b/gflownet/envs/conformers/conformer.py @@ -11,7 +11,7 @@ from gflownet.utils.molecule.constants import ad_atom_types from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import RDKitConformer -from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smile +from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles class Conformer(ContinuousTorus): @@ -65,7 +65,7 @@ def _get_positions(smiles: str) -> npt.NDArray: @staticmethod def _get_torsion_angles(smiles: str, indices: List[int]) -> List[Tuple[int]]: - torsion_angles = find_rotor_from_smile(smiles) + torsion_angles = find_rotor_from_smiles(smiles) torsion_angles = [torsion_angles[i] for i in indices] return torsion_angles From 00f614a2c18ef65b152613380a55e738887aeb24 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 19 Sep 2023 16:17:56 -0400 Subject: [PATCH 097/100] add hydrogens fix --- gflownet/utils/molecule/rotatable_bonds.py | 3 +-- tests/gflownet/utils/molecule/test_rotatable_bonds.py | 9 +++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index 1f4e7f970..7625a7b48 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -43,7 +43,6 @@ def get_rotatable_ta_list(mol): Torsion angle is considered rotatable if: - the bond (atom2, atom3) is a single bond, - - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles), - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed), - atom2 and atom3 are not in the same ring. @@ -71,7 +70,6 @@ def find_rotor_from_smiles(smiles): Torsion angle is considered rotatable if: - the bond (atom2, atom3) is a single bond, - - atom1 and atom4 are not hydrogens (ignore hydrogen torsion angles), - none of atom2 and atom3 are adjacent to a triple bond (as the bonds near the triple bonds must be fixed), - atom2 and atom3 are not in the same ring. @@ -83,4 +81,5 @@ def find_rotor_from_smiles(smiles): list of tuples: A list of unique torsion angle tuples corresponding to rotatable bonds in the molecule. """ mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) return get_rotatable_ta_list(mol) diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py b/tests/gflownet/utils/molecule/test_rotatable_bonds.py index 2395972d7..f101cd27b 100644 --- a/tests/gflownet/utils/molecule/test_rotatable_bonds.py +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -1,11 +1,12 @@ from rdkit import Chem from gflownet.utils.molecule import constants -from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smile +from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles def test_simple_ad(): - tas = find_rotor_from_smile(constants.ad_smiles) - assert len(tas) == 4 - expected = [[0, 1, 2, 3], [0, 1, 6, 7], [1, 2, 4, 5], [1, 6, 7, 8]] + tas = find_rotor_from_smiles(constants.ad_smiles) + assert len(tas) == 7 + expected = [[0, 1, 2, 3], [0, 1, 6, 7], [1, 2, 4, 5], [1, 6, 7, 8], + [2, 4, 5, 15], [6, 7, 9, 19], [10, 0, 1, 13]] assert tas == expected From 0ae2a8ab75d5e898b1a88b82f286fa0c2f362553 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 19 Sep 2023 18:05:23 -0400 Subject: [PATCH 098/100] fix ordering bug, add check for hydrogen ta --- gflownet/utils/molecule/rotatable_bonds.py | 21 ++++++++++++++++ .../utils/molecule/test_rotatable_bonds.py | 24 +++++++++++++++++-- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index 7625a7b48..c6424449e 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -26,6 +26,10 @@ def remove_duplicate_tas(tas_list): end = row[2] if not (begin, end) in considered and begin < end: duplicates = tas[np.logical_and(tas[:, 1] == begin, tas[:, 2] == end)] + duplicates_reversed = tas[np.logical_and(tas[:, 2] == begin, tas[:, 1] == end)] + duplicates_reversed = np.flip(duplicates_reversed, axis=1) + duplicates = np.concatenate([duplicates, duplicates_reversed], axis=0) + assert duplicates.shape[-1] == 4 duplicates = duplicates[ np.where(duplicates[:, 0] == duplicates[:, 0].min())[0] ] @@ -83,3 +87,20 @@ def find_rotor_from_smiles(smiles): mol = Chem.MolFromSmiles(smiles) mol = Chem.AddHs(mol) return get_rotatable_ta_list(mol) + +def is_hydrogen_ta(mol, ta): + """ + Simple check whether the given torsion angle is 'hydrogen torsion angle', i.e. + it effectively influences only positions of some hydrogens in the molecule + """ + def is_connected_to_three_hydrogens(mol, atom_id, except_id): + atom = mol.GetAtomWithIdx(atom_id) + neigh_numbers = [] + for n in atom.GetNeighbors(): + if n.GetIdx() != except_id: + neigh_numbers.append(n.GetAtomicNum()) + neigh_numbers = np.array(neigh_numbers) + return np.all(neigh_numbers == 1) + first = is_connected_to_three_hydrogens(mol, ta[1], ta[2]) + second = is_connected_to_three_hydrogens(mol, ta[2], ta[1]) + return first or second \ No newline at end of file diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py b/tests/gflownet/utils/molecule/test_rotatable_bonds.py index f101cd27b..100e90a99 100644 --- a/tests/gflownet/utils/molecule/test_rotatable_bonds.py +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -1,12 +1,32 @@ +import pytest from rdkit import Chem from gflownet.utils.molecule import constants -from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles +from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles, is_hydrogen_ta def test_simple_ad(): tas = find_rotor_from_smiles(constants.ad_smiles) assert len(tas) == 7 expected = [[0, 1, 2, 3], [0, 1, 6, 7], [1, 2, 4, 5], [1, 6, 7, 8], - [2, 4, 5, 15], [6, 7, 9, 19], [10, 0, 1, 13]] + [2, 4, 5, 15], [6, 7, 9, 19], [10, 0, 1, 2]] assert tas == expected + +@pytest.mark.parametrize( + 'ta, expected_flag', + [ + ([0, 1, 2, 3], False), + ([0, 1, 6, 7], False), + ([1, 2, 4, 5], False), + ([1, 6, 7, 8], False), + ([2, 4, 5, 15], True), + ([6, 7, 9, 19], True), + ([10, 0, 1, 2], True) + ] +) +def test_is_hydrogen_ta(ta, expected_flag): + mol = Chem.MolFromSmiles(constants.ad_smiles) + mol = Chem.AddHs(mol) + assert is_hydrogen_ta(mol, ta) == expected_flag + + From 2875505929dce3c01c473dc588f4dd109ed19dd8 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 19 Sep 2023 23:03:27 -0400 Subject: [PATCH 099/100] fix another bug --- gflownet/utils/molecule/rotatable_bonds.py | 4 +++- tests/gflownet/utils/molecule/test_rotatable_bonds.py | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index c6424449e..cb1743fb1 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -24,7 +24,9 @@ def remove_duplicate_tas(tas_list): for row in tas: begin = row[1] end = row[2] - if not (begin, end) in considered and begin < end: + if not (begin, end) in considered and not (end, begin) in considered: + if begin > end: + begin, end = end, begin duplicates = tas[np.logical_and(tas[:, 1] == begin, tas[:, 2] == end)] duplicates_reversed = tas[np.logical_and(tas[:, 2] == begin, tas[:, 1] == end)] duplicates_reversed = np.flip(duplicates_reversed, axis=1) diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py b/tests/gflownet/utils/molecule/test_rotatable_bonds.py index 100e90a99..f3ed2e88e 100644 --- a/tests/gflownet/utils/molecule/test_rotatable_bonds.py +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -9,7 +9,7 @@ def test_simple_ad(): tas = find_rotor_from_smiles(constants.ad_smiles) assert len(tas) == 7 expected = [[0, 1, 2, 3], [0, 1, 6, 7], [1, 2, 4, 5], [1, 6, 7, 8], - [2, 4, 5, 15], [6, 7, 9, 19], [10, 0, 1, 2]] + [10, 0, 1, 2], [2, 4, 5, 15], [6, 7, 9, 19]] assert tas == expected @pytest.mark.parametrize( @@ -29,4 +29,10 @@ def test_is_hydrogen_ta(ta, expected_flag): mol = Chem.AddHs(mol) assert is_hydrogen_ta(mol, ta) == expected_flag +def test_number_tas(): + smiles = 'CCCc1nnc(NC(=O)COc2ccc3c(c2)OCO3)s1' + expected = 8 + tas = find_rotor_from_smiles(smiles) + assert len(tas) == expected + From 9b45f15a3f3f431fbdf2ff199204e0f40a5bb1a9 Mon Sep 17 00:00:00 2001 From: michalkoziarski Date: Wed, 20 Sep 2023 18:59:53 -0400 Subject: [PATCH 100/100] black & isort --- gflownet/utils/molecule/rotatable_bonds.py | 9 +++- .../utils/molecule/test_rotatable_bonds.py | 44 ++++++++++++------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/gflownet/utils/molecule/rotatable_bonds.py b/gflownet/utils/molecule/rotatable_bonds.py index cb1743fb1..4dc73a590 100644 --- a/gflownet/utils/molecule/rotatable_bonds.py +++ b/gflownet/utils/molecule/rotatable_bonds.py @@ -28,7 +28,9 @@ def remove_duplicate_tas(tas_list): if begin > end: begin, end = end, begin duplicates = tas[np.logical_and(tas[:, 1] == begin, tas[:, 2] == end)] - duplicates_reversed = tas[np.logical_and(tas[:, 2] == begin, tas[:, 1] == end)] + duplicates_reversed = tas[ + np.logical_and(tas[:, 2] == begin, tas[:, 1] == end) + ] duplicates_reversed = np.flip(duplicates_reversed, axis=1) duplicates = np.concatenate([duplicates, duplicates_reversed], axis=0) assert duplicates.shape[-1] == 4 @@ -90,11 +92,13 @@ def find_rotor_from_smiles(smiles): mol = Chem.AddHs(mol) return get_rotatable_ta_list(mol) + def is_hydrogen_ta(mol, ta): """ Simple check whether the given torsion angle is 'hydrogen torsion angle', i.e. it effectively influences only positions of some hydrogens in the molecule """ + def is_connected_to_three_hydrogens(mol, atom_id, except_id): atom = mol.GetAtomWithIdx(atom_id) neigh_numbers = [] @@ -103,6 +107,7 @@ def is_connected_to_three_hydrogens(mol, atom_id, except_id): neigh_numbers.append(n.GetAtomicNum()) neigh_numbers = np.array(neigh_numbers) return np.all(neigh_numbers == 1) + first = is_connected_to_three_hydrogens(mol, ta[1], ta[2]) second = is_connected_to_three_hydrogens(mol, ta[2], ta[1]) - return first or second \ No newline at end of file + return first or second diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py b/tests/gflownet/utils/molecule/test_rotatable_bonds.py index f3ed2e88e..b316dc1c0 100644 --- a/tests/gflownet/utils/molecule/test_rotatable_bonds.py +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -2,37 +2,47 @@ from rdkit import Chem from gflownet.utils.molecule import constants -from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles, is_hydrogen_ta +from gflownet.utils.molecule.rotatable_bonds import ( + find_rotor_from_smiles, + is_hydrogen_ta, +) def test_simple_ad(): tas = find_rotor_from_smiles(constants.ad_smiles) assert len(tas) == 7 - expected = [[0, 1, 2, 3], [0, 1, 6, 7], [1, 2, 4, 5], [1, 6, 7, 8], - [10, 0, 1, 2], [2, 4, 5, 15], [6, 7, 9, 19]] + expected = [ + [0, 1, 2, 3], + [0, 1, 6, 7], + [1, 2, 4, 5], + [1, 6, 7, 8], + [10, 0, 1, 2], + [2, 4, 5, 15], + [6, 7, 9, 19], + ] assert tas == expected + @pytest.mark.parametrize( - 'ta, expected_flag', - [ - ([0, 1, 2, 3], False), - ([0, 1, 6, 7], False), - ([1, 2, 4, 5], False), - ([1, 6, 7, 8], False), - ([2, 4, 5, 15], True), - ([6, 7, 9, 19], True), - ([10, 0, 1, 2], True) - ] + "ta, expected_flag", + [ + ([0, 1, 2, 3], False), + ([0, 1, 6, 7], False), + ([1, 2, 4, 5], False), + ([1, 6, 7, 8], False), + ([2, 4, 5, 15], True), + ([6, 7, 9, 19], True), + ([10, 0, 1, 2], True), + ], ) def test_is_hydrogen_ta(ta, expected_flag): mol = Chem.MolFromSmiles(constants.ad_smiles) mol = Chem.AddHs(mol) assert is_hydrogen_ta(mol, ta) == expected_flag + def test_number_tas(): - smiles = 'CCCc1nnc(NC(=O)COc2ccc3c(c2)OCO3)s1' + smiles = "CCCc1nnc(NC(=O)COc2ccc3c(c2)OCO3)s1" expected = 8 tas = find_rotor_from_smiles(smiles) - assert len(tas) == expected - - + assert len(tas) == expected