[WIP] branch with discrete crystal env experiments #224

Closed
wants to merge 36 commits
36 commits
861fe62
remove missing elements padding
vict0rsch Sep 24, 2023
335cecb
add exp file
vict0rsch Sep 25, 2023
accd761
update params from discussion with Alex
vict0rsch Sep 25, 2023
1e59abc
Fix spacegroup _is_compatible()
carriepl-mila Sep 25, 2023
41d7659
default boltzmann
vict0rsch Sep 25, 2023
77a9c32
handle logger notes in wandb
vict0rsch Sep 25, 2023
3813bcc
describe notes config
vict0rsch Sep 25, 2023
802f218
rename to `discrete-matbench`
vict0rsch Sep 25, 2023
c69d6af
update hydra run dir
vict0rsch Sep 25, 2023
dbe9cf5
handle per-job git repo with `--code_dir='$SLURM_TMPDIR'`
vict0rsch Sep 25, 2023
8b64e69
user confirmation even just for missing `git_checkout`
vict0rsch Sep 25, 2023
22d5519
improve git warning logic
vict0rsch Sep 25, 2023
7e98702
handle quotes in generated command-line
vict0rsch Sep 25, 2023
7bbd17a
strip output for no new line
vict0rsch Sep 25, 2023
c1b489a
fix quotes and ssh to https
vict0rsch Sep 25, 2023
c42826e
print new line
vict0rsch Sep 25, 2023
d1f074b
handle possible "=" in notes
vict0rsch Sep 25, 2023
40bbc92
quote both key AND value if = in CLI
vict0rsch Sep 25, 2023
f26d7fa
typo: removed first level quoting
vict0rsch Sep 25, 2023
ea9db8c
improve dave docstring AND fix sg-1
vict0rsch Sep 26, 2023
406f18b
auto parse repo name
vict0rsch Sep 26, 2023
7acd17f
minor release: 0.4.0
vict0rsch Sep 26, 2023
0cc80ae
0.3.3
vict0rsch Sep 26, 2023
7380865
bump dave version
vict0rsch Sep 26, 2023
09d4346
nest sbatch dicts further
vict0rsch Sep 27, 2023
4675707
Space group now accepts an iterable of valid space groups to restrict…
alexhernandezgarcia Sep 23, 2023
9954269
update configs: non-nested policy
vict0rsch Sep 27, 2023
ee5e0d8
smaller search space
vict0rsch Sep 27, 2023
9510785
fix missing .sh extension
vict0rsch Sep 27, 2023
0dc22c2
fix int parsing
vict0rsch Sep 27, 2023
472c3a6
add wandb query
vict0rsch Sep 27, 2023
659c373
stateless get_mask/get_parents
michalkoziarski Sep 28, 2023
ceba97b
stateless get_parents
michalkoziarski Sep 28, 2023
626be08
added TODO
michalkoziarski Sep 28, 2023
d9ffa25
Merge pull request #236 from alexhernandezgarcia/discrete-matbench-bu…
vict0rsch Sep 29, 2023
d40cc5a
add missing params
vict0rsch Sep 29, 2023
10 changes: 8 additions & 2 deletions LAUNCH.md
@@ -9,8 +9,8 @@ usage: launch.py [-h] [--help-md] [--job_name JOB_NAME] [--outdir OUTDIR]
[--cpus_per_task CPUS_PER_TASK] [--mem MEM] [--gres GRES]
[--partition PARTITION] [--modules MODULES]
[--conda_env CONDA_ENV] [--venv VENV] [--template TEMPLATE]
[--code_dir CODE_DIR] [--jobs JOBS] [--dry-run] [--verbose]
[--force]
[--code_dir CODE_DIR] [--git_checkout GIT_CHECKOUT]
[--jobs JOBS] [--dry-run] [--verbose] [--force]

optional arguments:
-h, --help show this help message and exit
@@ -35,6 +35,11 @@ optional arguments:
$root/mila/sbatch/template-conda.sh
--code_dir CODE_DIR cd before running main.py (defaults to here). Defaults
to $root
--git_checkout GIT_CHECKOUT
Branch or commit to checkout before running the code.
This is only used if --code_dir='$SLURM_TMPDIR'. If
not specified, the current branch is used. Defaults to
None
--jobs JOBS jobs (nested) file name in external/jobs (with or
without .yaml). Or an absolute path to a yaml file
anywhere Defaults to None
@@ -54,6 +59,7 @@ conda_env : gflownet
cpus_per_task : 2
dry-run : False
force : False
git_checkout : None
gres : gpu:1
job_name : gflownet
jobs : None
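For context on the new `--git_checkout` flag documented above, here is a minimal sketch of the per-job checkout it describes, assuming `launch.py` clones the repository into `$SLURM_TMPDIR` and then checks out the requested ref; the function name and wiring are hypothetical, not taken from `launch.py`.

```python
# Hypothetical sketch of the behaviour documented for --git_checkout:
# it only applies when --code_dir='$SLURM_TMPDIR', and defaults to the current branch.
import os
import subprocess


def prepare_code_dir(code_dir: str, repo_url: str, git_checkout: str | None = None) -> str:
    code_dir = os.path.expandvars(code_dir)  # resolves $SLURM_TMPDIR on the compute node
    if git_checkout is None:
        # Default documented above: use whatever branch is currently checked out.
        git_checkout = subprocess.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True
        ).strip()
    subprocess.run(["git", "clone", repo_url, code_dir], check=True)
    subprocess.run(["git", "checkout", git_checkout], cwd=code_dir, check=True)
    return code_dir
```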
2 changes: 2 additions & 0 deletions config/env/crystals/crystal.yaml
@@ -7,6 +7,8 @@ _target_: gflownet.envs.crystals.crystal.Crystal
id: crystal
composition_kwargs:
elements: 89
max_atoms: 20
max_atom_i: 16
lattice_parameters_kwargs:
min_length: 1.0
max_length: 5.0
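A minimal sketch of how a config like the one above is consumed, assuming Hydra's `instantiate` forwards `composition_kwargs` and `lattice_parameters_kwargs` to the `Crystal` constructor, as the `_target_` line and the `crystal.py` diff further down suggest; the exact constructor signature is an assumption.

```python
# Hypothetical sketch: building the Crystal env from the config values shown above.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "gflownet.envs.crystals.crystal.Crystal",
        "composition_kwargs": {"elements": 89, "max_atoms": 20, "max_atom_i": 16},
        "lattice_parameters_kwargs": {"min_length": 1.0, "max_length": 5.0},
    }
)
env = instantiate(cfg)  # instantiate resolves _target_ and passes the kwargs blocks
```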
3 changes: 3 additions & 0 deletions config/env/crystals/spacegroup.yaml
@@ -4,6 +4,9 @@ defaults:
_target_: gflownet.envs.crystals.spacegroup.SpaceGroup

id: spacegroup

# Subset of space groups
space_groups_subset: null
# Stoichiometry
n_atoms: null

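The commit "Space group now accepts an iterable of valid space groups to restrict…" introduces `space_groups_subset`; below is a minimal sketch of the restriction it implies, with hypothetical names (the real filtering lives in `gflownet.envs.crystals.spacegroup.SpaceGroup`).

```python
# Hypothetical sketch of restricting a space group env to a subset of groups.
# `space_groups` maps international number -> symbol; names are illustrative only.
from typing import Dict, Iterable, Optional


def restrict_space_groups(
    space_groups: Dict[int, str], subset: Optional[Iterable[int]]
) -> Dict[int, str]:
    if subset is None:  # space_groups_subset: null -> keep all 230 groups
        return space_groups
    allowed = {int(sg) for sg in subset}
    return {idx: sym for idx, sym in space_groups.items() if idx in allowed}
```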
76 changes: 76 additions & 0 deletions config/experiments/workshop23/discrete-matbench.yaml
@@ -0,0 +1,76 @@
# @package _global_

defaults:
- override /env: crystals/crystal
- override /gflownet: trajectorybalance
- override /proxy: crystals/dave
- override /logger: wandb

device: cpu

# Environment
env:
lattice_parameters_kwargs:
min_length: 1.0
max_length: 350.0
min_angle: 50.0
max_angle: 150.0
grid_size: 10
composition_kwargs:
elements: [1,3,4,5,6,7,8,9,11,12,13,14,15,16,17,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,89,90,91,92,93,94]
reward_func: boltzmann
reward_beta: 1
buffer:
replay_capacity: 0

# GFlowNet hyperparameters
gflownet:
random_action_prob: 0.1
optimizer:
batch_size:
forward: 10
backward_replay: -1
lr: 0.001
z_dim: 16
lr_z_mult: 100
n_train_steps: 10000
lr_decay_period: 1000000
replay_sampling: weighted

policy:
forward:
type: mlp
n_hid: 512
n_layers: 5
checkpoint: forward
backward:
type: mlp
n_hid: 512
n_layers: 5
shared_weights: False
checkpoint: backward

# WandB
logger:
lightweight: True
project_name: "crystal-gfn"
tags:
- gflownet
- crystals
- matbench
- workshop23
checkpoints:
period: 500
do:
online: true
test:
period: -1
n: 500
n_top_k: 5000
top_k: 100
top_k_period: -1

# Hydra
hydra:
run:
dir: ${user.logdir.root}/workshop23/discrete-matbench/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S}
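This experiment sets `reward_func: boltzmann` with `reward_beta: 1`; a minimal sketch of a Boltzmann reward transform, assuming lower proxy energies should map to higher rewards (the exact sign and scaling conventions are defined in the gflownet codebase, not here).

```python
# Hypothetical sketch of a Boltzmann reward: reward = exp(-beta * energy).
import numpy as np


def boltzmann_reward(proxy_energy: np.ndarray, beta: float = 1.0) -> np.ndarray:
    # Lower predicted energy -> exponentially larger reward (assumed convention).
    return np.exp(-beta * proxy_energy)


rewards = boltzmann_reward(np.array([-1.2, 0.0, 0.8]), beta=1.0)
```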
78 changes: 78 additions & 0 deletions config/experiments/workshop23/mini-discrete-matbench.yaml
@@ -0,0 +1,78 @@
# @package _global_

defaults:
- override /env: crystals/crystal
- override /gflownet: trajectorybalance
- override /proxy: crystals/dave
- override /logger: wandb

device: cpu

# Environment
env:
lattice_parameters_kwargs:
min_length: 1.0
max_length: 50.0
min_angle: 50.0
max_angle: 150.0
grid_size: 20
composition_kwargs:
elements: [1, 3, 6, 7, 8, 9, 12, 14, 15, 16, 17, 26]
space_group_kwargs:
space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230]
reward_func: boltzmann
reward_beta: 1
buffer:
replay_capacity: 0

# GFlowNet hyperparameters
gflownet:
random_action_prob: 0.1
optimizer:
batch_size:
forward: 10
backward_replay: -1
lr: 0.001
z_dim: 16
lr_z_mult: 100
n_train_steps: 10000
lr_decay_period: 1000000
replay_sampling: weighted

policy:
forward:
type: mlp
n_hid: 512
n_layers: 5
checkpoint: forward
backward:
type: mlp
n_hid: 512
n_layers: 5
shared_weights: False
checkpoint: backward

# WandB
logger:
lightweight: True
project_name: "crystal-gfn"
tags:
- gflownet
- crystals
- matbench
- workshop23
checkpoints:
period: 500
do:
online: true
test:
period: -1
n: 500
n_top_k: 5000
top_k: 100
top_k_period: -1

# Hydra
hydra:
run:
dir: ${user.logdir.root}/workshop23/discrete-matbench/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S}
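Compared to the full experiment, this mini config narrows the search space: 12 elements, a restricted set of space groups, shorter maximum lattice lengths, and a finer `grid_size: 20`. A minimal illustration of the grid that implies, assuming each length and angle is discretized uniformly over its [min, max] range (the actual discretization lives in the LatticeParameters env).

```python
# Hypothetical illustration of the lattice-parameter grid implied by this config.
import numpy as np

grid_size = 20
lengths = np.linspace(1.0, 50.0, grid_size)   # min_length .. max_length
angles = np.linspace(50.0, 150.0, grid_size)  # min_angle .. max_angle
print(lengths[:3], angles[:3])
```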
1 change: 1 addition & 0 deletions config/logger/base.yaml
@@ -44,3 +44,4 @@ debug: False
lightweight: False
progress: True
context: "0"
notes: null # wandb run notes (e.g. "baseline")
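The new `notes` field goes with the "handle logger notes in wandb" commit; a minimal sketch of how it could be forwarded, assuming the logger passes it straight to `wandb.init` (the actual wiring is in gflownet's logger).

```python
# Hypothetical sketch: passing logger.notes through to Weights & Biases.
import wandb

notes = None  # value of logger.notes, e.g. "baseline"; null/None means no notes
run = wandb.init(project="crystal-gfn", notes=notes)
```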
2 changes: 1 addition & 1 deletion config/proxy/crystals/dave.yaml
@@ -1,6 +1,6 @@
_target_: gflownet.proxy.crystals.dave.DAVE

release: 0.3.2
release: 0.3.4
ckpt_path:
mila: /network/scratch/s/schmidtv/crystals-proxys/proxy-ckpts/
victor: ~/Documents/Github/ActiveLearningMaterials/checkpoints/980065c0/checkpoints-3ff648a2
46 changes: 25 additions & 21 deletions gflownet/envs/crystals/crystal.py
@@ -307,19 +307,18 @@ def get_mask_invalid_actions_forward(
] = space_group_mask
elif stage == Stage.LATTICE_PARAMETERS:
"""
TODO: to be stateless (meaning, operating as a function, not a method with
current object context) this needs to set lattice system based on the passed
state only. Right now it uses the current LatticeParameter environment, in
particular the lattice system that it was set to, and that changes the invalid
actions mask.

If for some reason a state will be passed to this method that describes an
object with different lattice system than what self.lattice_system contains,
the result will be invalid.
TODO: refactor below implementation such that it remains stateless,
but doesn't require creation of LatticeParameters object every time.
"""
lattice_system = self.space_group.get_lattice_system(
self._get_space_group_state(state)
)
lattice_parameters = LatticeParameters(
lattice_system=lattice_system, **self.lattice_parameters_kwargs
)
lattice_parameters_state = self._get_lattice_parameters_state(state)
lattice_parameters_mask = (
self.lattice_parameters.get_mask_invalid_actions_forward(
lattice_parameters.get_mask_invalid_actions_forward(
state=lattice_parameters_state, done=False
)
)
@@ -433,10 +432,16 @@ def get_parents(
)
parents = [self._build_state(p, Stage.COMPOSITION) for p in parents]
actions = [self._pad_action(a, Stage.COMPOSITION) for a in actions]
# TODO: refactor source check
elif stage == Stage.SPACE_GROUP or (
stage == Stage.LATTICE_PARAMETERS
and self._get_lattice_parameters_state(state)
== self.lattice_parameters.source
== LatticeParameters(
lattice_system=self.space_group.get_lattice_system(
self._get_space_group_state(state)
),
**self.lattice_parameters_kwargs,
).source
):
space_group_done = stage == Stage.LATTICE_PARAMETERS
parents, actions = self.space_group.get_parents(
@@ -446,17 +451,16 @@
actions = [self._pad_action(a, Stage.SPACE_GROUP) for a in actions]
elif stage == Stage.LATTICE_PARAMETERS:
"""
TODO: to be stateless (meaning, operating as a function, not a method with
current object context) this needs to set lattice system based on the passed
state only. Right now it uses the current LatticeParameter environment, in
particular the lattice system that it was set to, and that changes the invalid
actions mask.

If for some reason a state will be passed to this method that describes an
object with different lattice system than what self.lattice_system contains,
the result will be invalid.
TODO: refactor below implementation such that it remains stateless,
but doesn't require creation of LatticeParameters object every time.
"""
parents, actions = self.lattice_parameters.get_parents(
lattice_system = self.space_group.get_lattice_system(
self._get_space_group_state(state)
)
lattice_parameters = LatticeParameters(
lattice_system=lattice_system, **self.lattice_parameters_kwargs
)
parents, actions = lattice_parameters.get_parents(
state=self._get_lattice_parameters_state(state), done=done
)
parents = [self._build_state(p, Stage.LATTICE_PARAMETERS) for p in parents]
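The pattern in the diff above, rebuilding a `LatticeParameters` sub-environment from the passed state instead of reading `self.lattice_parameters`, is what makes `get_parents` and `get_mask_invalid_actions_forward` stateless. A condensed sketch of that pattern follows, with the caching hinted at by the remaining TODO added as a hypothetical optimization; the module path and helper names are assumptions, not repository code.

```python
# Hypothetical condensed form of the stateless pattern used in the diff above:
# derive the lattice system from the *passed* state and build the sub-env from it,
# caching one LatticeParameters instance per lattice system (the remaining TODO
# asks for something like this; the cache itself is an assumption).
from functools import lru_cache

from gflownet.envs.crystals.lattice_parameters import LatticeParameters  # assumed path


@lru_cache(maxsize=None)
def _lattice_parameters_env(lattice_system: str, kwargs_key: tuple) -> LatticeParameters:
    return LatticeParameters(lattice_system=lattice_system, **dict(kwargs_key))


def lattice_parameters_for(crystal_env, state) -> LatticeParameters:
    lattice_system = crystal_env.space_group.get_lattice_system(
        crystal_env._get_space_group_state(state)
    )
    kwargs_key = tuple(sorted(crystal_env.lattice_parameters_kwargs.items()))
    return _lattice_parameters_env(lattice_system, kwargs_key)
```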