From 0bfa82dc29d26ab18bf96cbb14fb20a352361756 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 4 Jul 2024 12:49:51 -0400 Subject: [PATCH 1/2] Restricted sampling: fix and implement passing only name of config --- scripts/eval_gflownet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index fcb509d4..cae83cff 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -14,6 +14,8 @@ sys.path.append(str(Path(__file__).resolve().parent.parent)) +from hydra.utils import instantiate + from gflownet.gflownet import GFlowNetAgent from gflownet.utils.common import load_gflow_net_from_run_path, read_hydra_config from gflownet.utils.policy import parse_policy_config @@ -202,11 +204,13 @@ def main(args): # ------------------------------------------ # Read conditional environment config, if provided - # TODO: implement allow passing just name of config if args.conditional_env_config_path is not None: - config_cond_env = read_hydra_config( - config_name=args.conditional_env_config_path - ) + conditional_env_config_path = Path(args.conditional_env_config_path) + if conditional_env_config_path.parent == Path("."): + conditional_env_config_path = ( + Path(args.run_path) / ".hydra" / conditional_env_config_path.name + ) + config_cond_env = read_hydra_config(config_name=conditional_env_config_path) if "env" in config_cond_env: config_cond_env = config_cond_env.env env_cond = instantiate( From a924e168d06259281cd3119ba42913807c0a6e4a Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 4 Jul 2024 12:50:49 -0400 Subject: [PATCH 2/2] Restricted sampling: configuration files --- .../crystals/starling_fe_restricted_a.yaml | 134 +++++++++++++++++ .../crystals/starling_fe_restricted_b.yaml | 135 +++++++++++++++++ .../crystals/starling_fe_restricted_c.yaml | 135 +++++++++++++++++ .../crystals/starling_fe_restricted_d.yaml | 136 ++++++++++++++++++ 4 files changed, 540 insertions(+) create mode 100644 config/experiments/crystals/starling_fe_restricted_a.yaml create mode 100644 config/experiments/crystals/starling_fe_restricted_b.yaml create mode 100644 config/experiments/crystals/starling_fe_restricted_c.yaml create mode 100644 config/experiments/crystals/starling_fe_restricted_d.yaml diff --git a/config/experiments/crystals/starling_fe_restricted_a.yaml b/config/experiments/crystals/starling_fe_restricted_a.yaml new file mode 100644 index 00000000..161303e1 --- /dev/null +++ b/config/experiments/crystals/starling_fe_restricted_a.yaml @@ -0,0 +1,134 @@ +# @package _global_ +# +# Restricted sampling A: only elements O and Fe, with maximum 10 atoms per element +# +# Forward trajectories (10) + Replay buffer (5) + Train set (5) +# Learning rate decay + +defaults: + - override /env: crystals/crystal + - override /gflownet: trajectorybalance + - override /proxy: crystals/dave + - override /logger: wandb + +device: cpu + +# Environment +env: + do_composition_to_sg_constraints: False + do_sg_to_composition_constraints: True + do_sg_to_lp_constraints: True + do_sg_before_composition: True + composition_kwargs: + elements: [8, 26] + max_diff_elem: 5 + min_diff_elem: 1 + min_atoms: 1 + max_atoms: 80 + min_atom_i: 1 + max_atom_i: 10 + do_charge_check: True + space_group_kwargs: + space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230] + lattice_parameters_kwargs: + min_length: 0.9 + max_length: 100.0 + min_angle: 50.0 + max_angle: 150.0 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + buffer: + replay_capacity: 1000 + train: + type: csv + path: /network/projects/crystalgfn/data/eform/train.csv + test: + type: csv + path: /network/projects/crystalgfn/data/eform/val.csv + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 10 + backward_replay: 5 + backward_dataset: 5 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 100000 + lr_decay_period: 11000 + lr_decay_gamma: 0.5 + replay_sampling: weighted + train_sampling: permutation + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + type: mlp + n_hid: 256 + n_layers: 3 + shared_weights: False + checkpoint: backward + +# Proxy (eform) +proxy: + reward_min: 1e-08 + do_clip_rewards: True + release: 0.3.4 # Formation energy release + # Boltzmann (exponential), with negative beta because the formation energy is negative and the lower the better + reward_function: exponential + # Parameters of the reward function + reward_function_kwargs: + beta: -8.0 + alpha: 1.0 + +# Evaluator +evaluator: + first_it: False + period: -1 + checkpoints_period: 500 + n_trajs_logprobs: 100 + logprobs_batch_size: 10 + n: 10 + n_top_k: 5000 + top_k: 100 + top_k_period: -1 + +# WandB +logger: + lightweight: True + project_name: "crystal-gfn" + tags: + - gflownet + - crystals + - stack + - matbench + - formationenergy + do: + online: true + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/crystalgfn/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S_%f} diff --git a/config/experiments/crystals/starling_fe_restricted_b.yaml b/config/experiments/crystals/starling_fe_restricted_b.yaml new file mode 100644 index 00000000..2c6e08fa --- /dev/null +++ b/config/experiments/crystals/starling_fe_restricted_b.yaml @@ -0,0 +1,135 @@ +# @package _global_ +# +# Restricted sampling B: Composition Li-Mn-O +# - max_diff_elem and min_diff_elem are set to 3 +# +# Forward trajectories (10) + Replay buffer (5) + Train set (5) +# Learning rate decay + +defaults: + - override /env: crystals/crystal + - override /gflownet: trajectorybalance + - override /proxy: crystals/dave + - override /logger: wandb + +device: cpu + +# Environment +env: + do_composition_to_sg_constraints: False + do_sg_to_composition_constraints: True + do_sg_to_lp_constraints: True + do_sg_before_composition: True + composition_kwargs: + elements: [3, 8, 25] + max_diff_elem: 3 + min_diff_elem: 3 + min_atoms: 1 + max_atoms: 80 + min_atom_i: 1 + max_atom_i: 16 + do_charge_check: True + space_group_kwargs: + space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230] + lattice_parameters_kwargs: + min_length: 0.9 + max_length: 100.0 + min_angle: 50.0 + max_angle: 150.0 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + buffer: + replay_capacity: 1000 + train: + type: csv + path: /network/projects/crystalgfn/data/eform/train.csv + test: + type: csv + path: /network/projects/crystalgfn/data/eform/val.csv + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 10 + backward_replay: 5 + backward_dataset: 5 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 100000 + lr_decay_period: 11000 + lr_decay_gamma: 0.5 + replay_sampling: weighted + train_sampling: permutation + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + type: mlp + n_hid: 256 + n_layers: 3 + shared_weights: False + checkpoint: backward + +# Proxy (eform) +proxy: + reward_min: 1e-08 + do_clip_rewards: True + release: 0.3.4 # Formation energy release + # Boltzmann (exponential), with negative beta because the formation energy is negative and the lower the better + reward_function: exponential + # Parameters of the reward function + reward_function_kwargs: + beta: -8.0 + alpha: 1.0 + +# Evaluator +evaluator: + first_it: False + period: -1 + checkpoints_period: 500 + n_trajs_logprobs: 100 + logprobs_batch_size: 10 + n: 10 + n_top_k: 5000 + top_k: 100 + top_k_period: -1 + +# WandB +logger: + lightweight: True + project_name: "crystal-gfn" + tags: + - gflownet + - crystals + - stack + - matbench + - formationenergy + do: + online: true + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/crystalgfn/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S_%f} diff --git a/config/experiments/crystals/starling_fe_restricted_c.yaml b/config/experiments/crystals/starling_fe_restricted_c.yaml new file mode 100644 index 00000000..7d89e214 --- /dev/null +++ b/config/experiments/crystals/starling_fe_restricted_c.yaml @@ -0,0 +1,135 @@ +# @package _global_ +# +# Restricted sampling C: only cubic lattices +# Space groups: 194,198,199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230 +# +# Forward trajectories (10) + Replay buffer (5) + Train set (5) +# Learning rate decay + +defaults: + - override /env: crystals/crystal + - override /gflownet: trajectorybalance + - override /proxy: crystals/dave + - override /logger: wandb + +device: cpu + +# Environment +env: + do_composition_to_sg_constraints: False + do_sg_to_composition_constraints: True + do_sg_to_lp_constraints: True + do_sg_before_composition: True + composition_kwargs: + elements: [1, 3, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 23, 25, 26, 27, 28, 29, 34] + max_diff_elem: 5 + min_diff_elem: 1 + min_atoms: 1 + max_atoms: 80 + min_atom_i: 1 + max_atom_i: 16 + do_charge_check: True + space_group_kwargs: + space_groups_subset: [194,198,199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230] + lattice_parameters_kwargs: + min_length: 0.9 + max_length: 100.0 + min_angle: 50.0 + max_angle: 150.0 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + buffer: + replay_capacity: 1000 + train: + type: csv + path: /network/projects/crystalgfn/data/eform/train.csv + test: + type: csv + path: /network/projects/crystalgfn/data/eform/val.csv + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 10 + backward_replay: 5 + backward_dataset: 5 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 100000 + lr_decay_period: 11000 + lr_decay_gamma: 0.5 + replay_sampling: weighted + train_sampling: permutation + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + type: mlp + n_hid: 256 + n_layers: 3 + shared_weights: False + checkpoint: backward + +# Proxy (eform) +proxy: + reward_min: 1e-08 + do_clip_rewards: True + release: 0.3.4 # Formation energy release + # Boltzmann (exponential), with negative beta because the formation energy is negative and the lower the better + reward_function: exponential + # Parameters of the reward function + reward_function_kwargs: + beta: -8.0 + alpha: 1.0 + +# Evaluator +evaluator: + first_it: False + period: -1 + checkpoints_period: 500 + n_trajs_logprobs: 100 + logprobs_batch_size: 10 + n: 10 + n_top_k: 5000 + top_k: 100 + top_k_period: -1 + +# WandB +logger: + lightweight: True + project_name: "crystal-gfn" + tags: + - gflownet + - crystals + - stack + - matbench + - formationenergy + do: + online: true + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/crystalgfn/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S_%f} diff --git a/config/experiments/crystals/starling_fe_restricted_d.yaml b/config/experiments/crystals/starling_fe_restricted_d.yaml new file mode 100644 index 00000000..a70235f2 --- /dev/null +++ b/config/experiments/crystals/starling_fe_restricted_d.yaml @@ -0,0 +1,136 @@ +# @package _global_ +# +# Restricted sampling D: lattice parameters range restricted to +# - lengths: 10-20 angstroms +# - angles: 75-135 +# +# Forward trajectories (10) + Replay buffer (5) + Train set (5) +# Learning rate decay + +defaults: + - override /env: crystals/crystal + - override /gflownet: trajectorybalance + - override /proxy: crystals/dave + - override /logger: wandb + +device: cpu + +# Environment +env: + do_composition_to_sg_constraints: False + do_sg_to_composition_constraints: True + do_sg_to_lp_constraints: True + do_sg_before_composition: True + composition_kwargs: + elements: [1, 3, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 23, 25, 26, 27, 28, 29, 34] + max_diff_elem: 5 + min_diff_elem: 1 + min_atoms: 1 + max_atoms: 80 + min_atom_i: 1 + max_atom_i: 16 + do_charge_check: True + space_group_kwargs: + space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230] + lattice_parameters_kwargs: + min_length: 10.0 + max_length: 20.0 + min_angle: 75.0 + max_angle: 135.0 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + buffer: + replay_capacity: 1000 + train: + type: csv + path: /network/projects/crystalgfn/data/eform/train.csv + test: + type: csv + path: /network/projects/crystalgfn/data/eform/val.csv + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 10 + backward_replay: 5 + backward_dataset: 5 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 100000 + lr_decay_period: 11000 + lr_decay_gamma: 0.5 + replay_sampling: weighted + train_sampling: permutation + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + type: mlp + n_hid: 256 + n_layers: 3 + shared_weights: False + checkpoint: backward + +# Proxy (eform) +proxy: + reward_min: 1e-08 + do_clip_rewards: True + release: 0.3.4 # Formation energy release + # Boltzmann (exponential), with negative beta because the formation energy is negative and the lower the better + reward_function: exponential + # Parameters of the reward function + reward_function_kwargs: + beta: -8.0 + alpha: 1.0 + +# Evaluator +evaluator: + first_it: False + period: -1 + checkpoints_period: 500 + n_trajs_logprobs: 100 + logprobs_batch_size: 10 + n: 10 + n_top_k: 5000 + top_k: 100 + top_k_period: -1 + +# WandB +logger: + lightweight: True + project_name: "crystal-gfn" + tags: + - gflownet + - crystals + - stack + - matbench + - formationenergy + do: + online: true + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/crystalgfn/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S_%f}