EGNN policy #181

Open
wants to merge 78 commits into base: molecule_graph_env

78 commits
9bb257a
edge-level EGNN
michalkoziarski Aug 25, 2023
da86d66
slightly modified batching
michalkoziarski Aug 25, 2023
7a1d818
increased EGNN depth
michalkoziarski Aug 25, 2023
739f5ea
creating smaller core graph
michalkoziarski Aug 25, 2023
a802d32
alanine dipeptide with torchani and EGNN config
michalkoziarski Aug 25, 2023
8c4115a
device
michalkoziarski Aug 25, 2023
e21f6bc
shorter trajectory length
michalkoziarski Aug 25, 2023
604d897
shorter trajectory length
michalkoziarski Aug 25, 2023
b32ef44
speed improvement
michalkoziarski Aug 25, 2023
37aa707
Merge branch 'main' into egnn_edge_policy
michalkoziarski Aug 30, 2023
da02c61
Merge branch 'rfgfn_test_update' into egnn_edge_policy
michalkoziarski Aug 30, 2023
4973b2c
fixed test
michalkoziarski Aug 30, 2023
5118c3b
black
michalkoziarski Aug 30, 2023
8c19261
isort
michalkoziarski Aug 30, 2023
d79ea88
isort; removing unused imports
michalkoziarski Aug 30, 2023
1e7c9bf
DGL added to CI
michalkoziarski Aug 30, 2023
ef35885
updated configs
michalkoziarski Aug 30, 2023
d1bc4e2
support for separate mlp heads for each torsion angle
michalkoziarski Aug 31, 2023
2721251
support for fake edges
michalkoziarski Aug 31, 2023
fe12fb8
updated config
michalkoziarski Aug 31, 2023
d67bf86
hydrogen removal for gnn policy
michalkoziarski Sep 4, 2023
a57a8b7
Merge branch 'main' into egnn_edge_policy
michalkoziarski Sep 6, 2023
684058a
Update test of backward sampling.
alexhernandezgarcia Sep 6, 2023
3995536
Small change of comment.
alexhernandezgarcia Sep 6, 2023
ebc12a3
Implement test to catch source of bug - currently not fixed. Skip som…
alexhernandezgarcia Sep 6, 2023
f2e5c64
Merge branch 'main' into egnn_edge_policy
michalkoziarski Sep 7, 2023
08b9839
Merge remote-tracking branch 'origin/backward-sampling-continuous' in…
michalkoziarski Sep 7, 2023
dce789a
added replay buffer (this time for real)
michalkoziarski Sep 7, 2023
442853b
using fake edges and separate MLPs per torsion
michalkoziarski Sep 7, 2023
809e345
Merge main and resolve conflicts
alexhernandezgarcia Sep 7, 2023
06e2d38
Merge branch 'backward-sampling-continuous' into fix-bws-crystal
alexhernandezgarcia Sep 7, 2023
d8e6050
Merge remote-tracking branch 'origin/backward-sampling-continuous' in…
michalkoziarski Sep 7, 2023
108ea47
set_state fix
michalkoziarski Sep 7, 2023
aced258
set_state extended to use done for sub-envs
michalkoziarski Sep 7, 2023
a349e19
Merge branch 'molecule_graph_env' into egnn_edge_policy
michalkoziarski Sep 7, 2023
a7dc3e0
Merge pull request #198 from alexhernandezgarcia/fix-bws-crystal-mk
michalkoziarski Sep 7, 2023
afcd5d6
Merge branch 'molecule_graph_env' into egnn_edge_policy
michalkoziarski Sep 7, 2023
bebccb9
Merge branch 'fix-bws-crystal' of https://github.com/alexhernandezgar…
michalkoziarski Sep 8, 2023
089969b
lattice system synchronization
michalkoziarski Sep 8, 2023
4ab9935
docstring change
michalkoziarski Sep 8, 2023
6d8b8ef
docstring change
michalkoziarski Sep 8, 2023
f567eb8
Merge pull request #203 from alexhernandezgarcia/fix-bws-crystal-mk-2
michalkoziarski Sep 8, 2023
b2d86a4
using ibuprofen instead of ADP
michalkoziarski Sep 11, 2023
cedec5e
reduced lr_z_mult
michalkoziarski Sep 11, 2023
e4c51af
actually ibuprofen
michalkoziarski Sep 11, 2023
4cbafdd
Merge branch 'fix-bws-crystal' of github.com:alexhernandezgarcia/gflo…
alexhernandezgarcia Sep 11, 2023
0396281
added S to atom types
michalkoziarski Sep 11, 2023
f08c4b3
Re-enable some tests.
alexhernandezgarcia Sep 11, 2023
522a31d
Control number of repetitions and batch size with global variables an…
alexhernandezgarcia Sep 11, 2023
69ff0ed
N_REPETITIONS <- N_REPEATS
alexhernandezgarcia Sep 11, 2023
4c76175
Add dev notes - kind of dummy changes to re-launch failed tests.
alexhernandezgarcia Sep 11, 2023
f12a057
hotfix to enable sweeps
michalkoziarski Sep 11, 2023
7a4cf06
added additional molecules
michalkoziarski Sep 12, 2023
1cc996d
shared policy config
michalkoziarski Sep 12, 2023
27e17e9
parse_policy_config moved to utils; backward compatibility fix
michalkoziarski Sep 12, 2023
0bf49fc
fixed tests
michalkoziarski Sep 12, 2023
4bd18c1
Merge branch 'shared_policy_config' into egnn_edge_policy
michalkoziarski Sep 12, 2023
6581504
reverted lr_z_mult
michalkoziarski Sep 13, 2023
9020ac0
configs
michalkoziarski Sep 13, 2023
b753b5f
shared policy config update
michalkoziarski Sep 13, 2023
d8cc130
isort
carriepl-mila Sep 13, 2023
97b0997
Merge pull request #195 from alexhernandezgarcia/fix-bws-crystal
alexhernandezgarcia Sep 13, 2023
ba4d7e1
Merge pull request #205 from alexhernandezgarcia/shared_policy_config
michalkoziarski Sep 13, 2023
699bee7
Merge branch 'main' of https://github.com/alexhernandezgarcia/gflownet
michalkoziarski Sep 14, 2023
b6930a6
Merge branch 'main' into egnn_edge_policy
michalkoziarski Sep 15, 2023
6e45a5a
whitespace
michalkoziarski Sep 15, 2023
e3c6aac
fixed rotatatable bonds
AlexandraVolokhova Sep 18, 2023
3dda1f8
add test
AlexandraVolokhova Sep 18, 2023
e792e49
isort fix
michalkoziarski Sep 18, 2023
34d70a0
formatting changes & typo fixes
michalkoziarski Sep 19, 2023
dc7233e
black
michalkoziarski Sep 19, 2023
270f36d
support for using all available TAs
michalkoziarski Sep 19, 2023
8c34a23
added Cl
michalkoziarski Sep 19, 2023
11044a9
Merge branch 'torsion_angles_detection_fix_mk' into egnn_edge_policy
michalkoziarski Sep 19, 2023
1a4f7c3
fixed function name
michalkoziarski Sep 19, 2023
5db4150
increased default backward replay size
michalkoziarski Sep 19, 2023
706263d
added F
michalkoziarski Sep 19, 2023
a349dea
added triple bond type
michalkoziarski Sep 19, 2023
2 changes: 2 additions & 0 deletions config/env/conformers/conformer.yaml
@@ -8,6 +8,8 @@ _target_: gflownet.envs.conformers.conformer.Conformer
 smiles: 'O=C(c1ccc2n1CCC2C(=O)O)c3ccccc3' # ketorolac
 # smiles: 'CCCCCC1=CC(=C(C(=C1)O)C2C=C(CCC2C(=C)C)C)O' # cannabidiol
 # smiles: 'CN1C2CCC1C(C(C2)OC(=O)C3=CC=CC=C3)C(=O)OC' # cocaine
+# smiles: 'Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C' # imatinib
+# smiles: 'CC(C)c4nc(CN(C)C(=O)N[C@@H](C(C)C)C(=O)N[C@@H](Cc1ccccc1)C[C@H](O)[C@H](Cc2ccccc2)NC(=O)OCc3cncs3)cs4' # ritonavir
 n_torsion_angles: 2
 reward_sampling_method: nested
58 changes: 58 additions & 0 deletions config/experiments/iclr23/egnn_torchani.yaml
@@ -0,0 +1,58 @@
+# @package _global_
+
+defaults:
+  - override /env: conformers/conformer
+  - override /gflownet: trajectorybalance
+  - override /policy: conformers/egnn
+  - override /proxy: conformers/torchani
+  - override /logger: wandb
+
+# Environment
+env:
+  smiles: CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O
+  n_torsion_angles: 2
+  remove_hs: True
+  length_traj: 5
+  policy_encoding_dim_per_angle: 10
+  policy_type: gnn
+  n_comp: 5
+  vonmises_min_concentration: 4
+  reward_func: boltzmann
+  reward_beta: 32
+  reward_sampling_method: nested
+  buffer:
+    replay_capacity: 1000
+
+# GFlowNet hyperparameters
+gflownet:
+  random_action_prob: 0.1
+  replay_sampling: weighted
+  optimizer:
+    batch_size:
+      forward: 90
+      backward_dataset: 0
+      backward_replay: 10
+    lr: 0.0001
+    z_dim: 16
+    lr_z_mult: 1000
+    n_train_steps: 40000
+    lr_decay_period: 1000000
+
+# WandB
+logger:
+  lightweight: True
+  project_name: "gflownet"
+  tags:
+    - gflownet
+    - continuous
+    - molecule
+  test:
+    period: 2000
+    n: 10000
+  checkpoints:
+    period: 2000
+
+# Hydra
+hydra:
+  run:
+    dir: ${user.logdir.root}/molecule/${now:%Y-%m-%d_%H-%M-%S}
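As a quick sanity check on the optimizer block of this config, the sketch below computes the per-step batch composition and the learning rate that would apply to logZ. It assumes that the three `batch_size` entries are summed into one training batch and that `lr_z_mult` multiplies `lr` for the logZ parameters; neither is stated in this diff. `batch_summary` is a hypothetical helper, not part of the repository.

```python
import math


def batch_summary(batch_size: dict, lr: float, lr_z_mult: float) -> dict:
    """Summarise a batch_size mapping (hypothetical helper, not a repository function)."""
    total = sum(batch_size.values())
    return {
        "total": total,
        # Share of each trajectory source in one training batch.
        "fractions": {name: n / total for name, n in batch_size.items()},
        # Assumption: logZ is trained with lr scaled by lr_z_mult.
        "lr_logz": lr * lr_z_mult,
    }


summary = batch_summary(
    {"forward": 90, "backward_dataset": 0, "backward_replay": 10},
    lr=0.0001,
    lr_z_mult=1000,
)
print(summary["total"], summary["fractions"]["backward_replay"])
print(math.isclose(summary["lr_logz"], 0.1))
```

Under these assumptions, one tenth of each batch comes from backward trajectories sampled from the replay buffer, and logZ is updated with a much larger step than the policy weights.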
57 changes: 57 additions & 0 deletions config/experiments/iclr23/mlp_torchani.yaml
@@ -0,0 +1,57 @@
+# @package _global_
+
+defaults:
+  - override /env: conformers/conformer
+  - override /gflownet: trajectorybalance
+  - override /policy: conformers/mlp
+  - override /proxy: conformers/torchani
+  - override /logger: wandb
+
+# Environment
+env:
+  smiles: Cc1ccc(cc1Nc2nccc(n2)c3cccnc3)NC(=O)c4ccc(cc4)CN5CCN(CC5)C
+  n_torsion_angles: 5
+  length_traj: 5
+  policy_encoding_dim_per_angle: 10
+  policy_type: mlp
+  n_comp: 5
+  vonmises_min_concentration: 4
+  reward_func: boltzmann
+  reward_beta: 32
+  reward_sampling_method: nested
+  buffer:
+    replay_capacity: 1000
+
+# GFlowNet hyperparameters
+gflownet:
+  random_action_prob: 0.1
+  replay_sampling: weighted
+  optimizer:
+    batch_size:
+      forward: 80
+      backward_dataset: 0
+      backward_replay: 20
+    lr: 0.0001
+    z_dim: 16
+    lr_z_mult: 1000
+    n_train_steps: 40000
+    lr_decay_period: 1000000
+
+# WandB
+logger:
+  lightweight: True
+  project_name: "gflownet"
+  tags:
+    - gflownet
+    - continuous
+    - molecule
+  test:
+    period: 2000
+    n: 10000
+  checkpoints:
+    period: 2000
+
+# Hydra
+hydra:
+  run:
+    dir: ${user.logdir.root}/molecule/${now:%Y-%m-%d_%H-%M-%S}
7 changes: 3 additions & 4 deletions config/experiments/tree.yaml
@@ -38,13 +38,12 @@ gflownet:

 # MLP policy
 policy:
-  forward:
-    n_hid: 256
-    n_layers: 3
-  backward:
+  shared:
     type: mlp
     n_hid: 256
     n_layers: 3
+  forward: null
+  backward:
     shared_weights: False

 # WandB
20 changes: 8 additions & 12 deletions config/policy/conformers/egnn.yaml
@@ -1,16 +1,12 @@
 _target_: gflownet.policy.conformers.egnn.EGNNPolicy

-forward:
+shared:
   n_gnn_layers: 7
-  n_node_mlp_layers: 2
-  n_pool_mlp_layers: 2
+  n_mlp_layers: 2
   egnn_hidden_dim: 128
-  node_mlp_hidden_dim: 128
-  pool_mlp_hidden_dim: 128
-backward:
-  n_gnn_layers: 7
-  n_node_mlp_layers: 2
-  n_pool_mlp_layers: 2
-  egnn_hidden_dim: 128
-  node_mlp_hidden_dim: 128
-  pool_mlp_hidden_dim: 128
+  mlp_hidden_dim: 256
+  separate_mlp_per_torsion: True
+  use_fake_edges: True
+  fake_edge_radius: 2.0
+forward: null
+backward: null
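The `use_fake_edges` and `fake_edge_radius` options are not explained in this diff. One plausible reading, and only an assumption here, is that atom pairs closer than the radius receive an extra non-bond edge so that message passing can also reach through-space neighbours. The sketch below illustrates that idea with plain Python; `add_fake_edges` is a hypothetical name and the actual EGNN policy may build these edges differently.

```python
import math
from itertools import combinations


def add_fake_edges(positions, bonds, radius):
    """Return extra (i, j) pairs for non-bonded atoms within `radius` of each other.

    Hypothetical helper: an assumed interpretation of use_fake_edges /
    fake_edge_radius, not the repository's implementation.
    """
    bonded = {frozenset(bond) for bond in bonds}
    fake = []
    for i, j in combinations(range(len(positions)), 2):
        if frozenset((i, j)) in bonded:
            continue  # real bonds are already edges
        if math.dist(positions[i], positions[j]) <= radius:
            fake.append((i, j))
    return fake


# Toy geometry: a three-atom chain plus one atom sitting above the middle atom.
positions = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0), (1.5, 1.0, 0.0)]
bonds = [(0, 1), (1, 2)]
fake = add_fake_edges(positions, bonds, radius=2.0)
print(fake)
```

With the radius of 2.0 used in the config, the two chain ends (3.0 apart) stay unconnected, while the fourth atom is linked to all three chain atoms.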
6 changes: 2 additions & 4 deletions config/policy/conformers/mlp.yaml
@@ -1,13 +1,11 @@
 _target_: gflownet.policy.base.Policy

-forward:
+shared:
   type: mlp
   n_hid: 512
   n_layers: 5
+forward:
   checkpoint: forward
 backward:
-  type: mlp
-  n_hid: 512
-  n_layers: 5
   shared_weights: False
   checkpoint: backward
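The `shared` block reads as a set of defaults that `forward` and `backward` then override. The sketch below assumes exactly that semantics, with a `null` entry inheriting everything from `shared`. The real logic lives in `parse_policy_config`, whose behaviour is not shown in this diff, so `resolve_policy_config` is a hypothetical stand-in rather than the repository's function.

```python
def resolve_policy_config(config: dict, kind: str) -> dict:
    """Overlay the `forward` or `backward` section on top of `shared`.

    Hypothetical helper illustrating the assumed merge semantics.
    """
    if kind not in ("forward", "backward"):
        raise ValueError(f"Unknown policy kind: {kind}")
    shared = config.get("shared") or {}  # a null section counts as empty
    specific = config.get(kind) or {}
    return {**shared, **specific}  # specific keys win over shared ones


# Mirrors config/policy/conformers/mlp.yaml after this change.
conformers_mlp = {
    "shared": {"type": "mlp", "n_hid": 512, "n_layers": 5},
    "forward": {"checkpoint": "forward"},
    "backward": {"shared_weights": False, "checkpoint": "backward"},
}
forward_cfg = resolve_policy_config(conformers_mlp, "forward")
backward_cfg = resolve_policy_config(conformers_mlp, "backward")
print(forward_cfg)
print(backward_cfg)
```

Under this reading, a config with `forward: null` and `backward: null` gives both policies identical hyperparameters taken from `shared`, which is how the EGNN config in this PR is written.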
2 changes: 2 additions & 0 deletions config/policy/mlp.yaml
@@ -1,5 +1,7 @@
 _target_: gflownet.policy.base.Policy

+shared: null
+
 forward:
   type: mlp
   n_hid: 128
3 changes: 1 addition & 2 deletions gflownet/envs/base.py
@@ -406,8 +406,7 @@ def step_backwards(
             return self.state, action, False
         parents, parents_a = self.get_parents()
         state_next = parents[parents_a.index(action)]
-        self.state = state_next
-        self.done = False
+        self.set_state(state_next, done=False)
         self.n_actions += 1
         return self.state, action, True

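Routing the update through `set_state` matters once a subclass overrides it. The classes below are hypothetical minimal stand-ins (not the repository's environment classes) that show the difference: with direct attribute assignment, the subclass hook would never run during a backward step.

```python
class BaseEnv:
    """Hypothetical minimal environment with a single state and a done flag."""

    def __init__(self):
        self.state = [0]
        self.done = False

    def set_state(self, state, done=False):
        self.state = state
        self.done = done
        return self

    def step_backwards_to(self, parent_state):
        # Mirrors the changed lines: delegate to set_state instead of
        # assigning self.state and self.done directly.
        self.set_state(parent_state, done=False)
        return self.state


class CompositeEnv(BaseEnv):
    """Hypothetical subclass that keeps derived data in sync inside set_state."""

    def __init__(self):
        super().__init__()
        self.sub_state = None

    def set_state(self, state, done=False):
        super().set_state(state, done)
        self.sub_state = list(state)  # stands in for syncing sub-environments
        return self


env = CompositeEnv()
env.set_state([3], done=True)
env.step_backwards_to([2])
print(env.state, env.done, env.sub_state)
```

Because `step_backwards_to` calls the overridable method, the derived `sub_state` follows the backward transition without the base class knowing about it.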
37 changes: 31 additions & 6 deletions gflownet/envs/conformers/conformer.py
@@ -1,8 +1,10 @@
 import copy
 from typing import List, Optional, Tuple

+import dgl
 import numpy as np
 import numpy.typing as npt
+import torch
 from rdkit import Chem
 from rdkit.Chem import AllChem
 from torchtyping import TensorType
@@ -11,7 +13,7 @@
 from gflownet.utils.molecule.constants import ad_atom_types
 from gflownet.utils.molecule.featurizer import MolDGLFeaturizer
 from gflownet.utils.molecule.rdkit_conformer import RDKitConformer
-from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smile
+from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles


 class Conformer(ContinuousTorus):
@@ -26,13 +28,16 @@ def __init__(
         n_torsion_angles: Optional[int] = 2,
         torsion_indices: Optional[List[int]] = None,
         policy_type: str = "mlp",
+        remove_hs: bool = True,
         **kwargs,
     ):
         if torsion_indices is None:
             # We hard code default torsion indices for Alanine Dipeptide to preserve
             # backward compatibility.
             if smiles == "CC(C(=O)NC)NC(=O)C" and n_torsion_angles == 2:
                 torsion_indices = [2, 1]
+            elif n_torsion_angles == -1:
+                torsion_indices = None
             else:
                 torsion_indices = list(range(n_torsion_angles))

@@ -50,12 +55,27 @@ def __init__(
                 f"Unrecognized policy_type = {policy_type}, expected either 'mlp' or 'gnn'."
             )

+        self.graph = MolDGLFeaturizer(ad_atom_types).mol2dgl(self.conformer.rdk_mol)
+        # TODO: use DGL conformer instead
+        rotatable_edges = [ta[1:3] for ta in torsion_angles]
+        for i in range(self.graph.num_edges()):
+            if (
+                self.graph.edges()[0][i].item(),
+                self.graph.edges()[1][i].item(),
+            ) not in rotatable_edges:
+                self.graph.edata["rotatable_edges"][i] = False
+
+        # Hydrogen removal
+        self.remove_hs = remove_hs
+        self.hs = torch.where(self.graph.ndata["atom_features"][:, 0] == 1)[0]
+        self.non_hs = torch.where(self.graph.ndata["atom_features"][:, 0] != 1)[0]
+        if remove_hs:
+            self.graph = dgl.remove_nodes(self.graph, self.hs)
+
         super().__init__(n_dim=len(self.conformer.freely_rotatable_tas), **kwargs)

         self.sync_conformer_with_state()

-        self.graph = MolDGLFeaturizer(ad_atom_types).mol2dgl(self.conformer.rdk_mol)
-
     @staticmethod
     def _get_positions(smiles: str) -> npt.NDArray:
         mol = Chem.MolFromSmiles(smiles)
@@ -64,9 +84,12 @@ def _get_positions(smiles: str) -> npt.NDArray:
         return mol.GetConformer().GetPositions()

     @staticmethod
-    def _get_torsion_angles(smiles: str, indices: List[int]) -> List[Tuple[int]]:
-        torsion_angles = find_rotor_from_smile(smiles)
-        torsion_angles = [torsion_angles[i] for i in indices]
+    def _get_torsion_angles(
+        smiles: str, indices: Optional[List[int]]
+    ) -> List[Tuple[int]]:
+        torsion_angles = find_rotor_from_smiles(smiles)
+        if indices is not None:
+            torsion_angles = [torsion_angles[i] for i in indices]
         return torsion_angles

     def sync_conformer_with_state(self, state: List = None):
def sync_conformer_with_state(self, state: List = None):
Expand Down Expand Up @@ -109,6 +132,8 @@ def statebatch2policy_gnn(self, states: List[List]) -> npt.NDArray[np.float32]:
for state in states:
conformer = self.sync_conformer_with_state(state)
positions = conformer.get_atom_positions()
if self.remove_hs:
positions = positions[self.non_hs]
policy_input.append(
np.concatenate(
[positions, np.full((positions.shape[0], 1), state[-1])],
Expand Down
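The graph-side changes in this file (rotatable-edge marking, hydrogen removal, and the matching filter on atom positions) can be illustrated without DGL or torch. The sketch assumes each torsion angle is a 4-tuple of atom indices whose middle two entries form the rotatable bond, and that the first atom-feature column holds the atomic number (so hydrogens have value 1), which is how the diff appears to use them. Plain lists stand in for the graph, and all names are hypothetical.

```python
def rotatable_edge_mask(edges, torsion_angles):
    """Return one bool per directed edge: True if it is the central bond of a torsion.

    As in the diff, only the stored direction (ta[1], ta[2]) is matched.
    """
    rotatable = {tuple(ta[1:3]) for ta in torsion_angles}
    return [tuple(edge) in rotatable for edge in edges]


def split_hydrogens(atomic_numbers):
    """Return (hydrogen indices, heavy-atom indices)."""
    hs = [i for i, z in enumerate(atomic_numbers) if z == 1]
    non_hs = [i for i, z in enumerate(atomic_numbers) if z != 1]
    return hs, non_hs


# Toy molecule: three heavy atoms (C, C, O) and two hydrogens; made-up coordinates.
atomic_numbers = [6, 6, 8, 1, 1]
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 3), (3, 0), (2, 4), (4, 2)]
torsion_angles = [(3, 0, 1, 2)]
positions = [
    (0.0, 0.0, 0.0),
    (1.5, 0.0, 0.0),
    (2.2, 1.2, 0.0),
    (-0.6, 0.9, 0.0),
    (3.1, 1.3, 0.0),
]

mask = rotatable_edge_mask(edges, torsion_angles)
hs, non_hs = split_hydrogens(atomic_numbers)
# Same filter as `positions[self.non_hs]`, so positions line up with the reduced graph.
heavy_positions = [positions[i] for i in non_hs]
print(mask, hs, non_hs)
```

The position filter has to use the same heavy-atom indices as the node removal; otherwise the coordinates passed to the policy would no longer match the nodes of the pruned graph.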
35 changes: 35 additions & 0 deletions gflownet/envs/crystals/crystal.py
@@ -522,6 +522,41 @@ def statetorch2oracle(
             dim=1,
         )

+    def set_state(self, state: List, done: Optional[bool] = False):
+        super().set_state(state, done)
+
+        stage = self._get_stage(state)
+
+        composition_done = stage in [Stage.SPACE_GROUP, Stage.LATTICE_PARAMETERS]
+        space_group_done = stage == Stage.LATTICE_PARAMETERS
+        lattice_parameters_done = done
+
+        self.composition.set_state(self._get_composition_state(state), composition_done)
+        self.space_group.set_state(self._get_space_group_state(state), space_group_done)
+        self.lattice_parameters.set_state(
+            self._get_lattice_parameters_state(state), lattice_parameters_done
+        )
+
+        """
+        We synchronize LatticeParameter's lattice system with the one of SpaceGroup
+        (if it was set) or reset it to the default triclinic otherwise. Why this is
+        needed:
+        1) the first case is necessary for backward sampling, where we start from
+        an arbitrary terminal state, and need to synchronize the LatticeParameter's
+        lattice system to what that state indicates,
+        2) the second case is also necessary in backward sampling, but when we
+        transition from Stage.LATTICE_PARAMETERS to Stage.SPACE_GROUP. We then need
+        to reset the lattice system to the default triclinic, such that its
+        source is back to the original one, and corresponds to the source of the
+        general Crystal environment.
+        """
+        lattice_system = self.space_group.lattice_system
+        if lattice_system != "None":
+            self.lattice_parameters.lattice_system = lattice_system
+        else:
+            self.lattice_parameters.lattice_system = TRICLINIC
+        self.lattice_parameters._set_source()
+
     def state2readable(self, state: Optional[List[int]] = None) -> str:
         if state is None:
             state = self.state
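The stage-to-done mapping and the lattice-system fallback in `set_state` can be isolated as two small pure functions. This is a sketch under assumptions: the `Stage` enum below is a stand-in whose `COMPOSITION` member and numeric values are guesses (only `SPACE_GROUP` and `LATTICE_PARAMETERS` appear in the diff), the value of `TRICLINIC` is assumed, and both function names are hypothetical.

```python
from enum import Enum


class Stage(Enum):
    # Stand-in enum; member values and COMPOSITION are assumptions.
    COMPOSITION = 0
    SPACE_GROUP = 1
    LATTICE_PARAMETERS = 2


TRICLINIC = "triclinic"  # assumed value of the repository constant


def sub_env_done_flags(stage: Stage, done: bool):
    """Return (composition_done, space_group_done, lattice_parameters_done)."""
    composition_done = stage in (Stage.SPACE_GROUP, Stage.LATTICE_PARAMETERS)
    space_group_done = stage == Stage.LATTICE_PARAMETERS
    lattice_parameters_done = done
    return composition_done, space_group_done, lattice_parameters_done


def resolve_lattice_system(space_group_lattice_system: str) -> str:
    """Use the space group's lattice system if set, else fall back to triclinic."""
    if space_group_lattice_system != "None":
        return space_group_lattice_system
    return TRICLINIC


print(sub_env_done_flags(Stage.SPACE_GROUP, done=False))
print(resolve_lattice_system("None"), resolve_lattice_system("cubic"))
```

A sub-environment counts as done as soon as the overall state has moved past its stage, while the last sub-environment is done only when the whole trajectory is.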
3 changes: 2 additions & 1 deletion gflownet/envs/crystals/spacegroup.py
@@ -158,7 +158,8 @@ def get_mask_invalid_actions_forward(
             mask[-1] = False
             return mask
         state_type = self.get_state_type(state)
-        # No constraints if neither crystal-lattice system nor point symmetry selected
+        # If neither crystal-lattice system nor point symmetry selected, apply only
+        # composition-compatibility constraints
         if cls_idx == 0 and ps_idx == 0:
             crystal_lattice_systems = [
                 (self.cls_idx, idx + 1, state_type)