Skip to content

Commit

Permalink
Merge pull request #285 from alexhernandezgarcia/scrabble-rebased
Browse files Browse the repository at this point in the history
Scrabble environment
  • Loading branch information
alexhernandezgarcia authored Feb 12, 2024
2 parents 62c9d52 + f6baedd commit 3e89eb7
Show file tree
Hide file tree
Showing 22 changed files with 468,660 additions and 19 deletions.
15 changes: 15 additions & 0 deletions config/env/scrabble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
defaults:
- base

_target_: gflownet.envs.scrabble.Scrabble

id: scrabble
# Buffer
buffer:
data_path: null
train: null
test:
type: uniform
n: 10
output_csv: scrabble_test.csv
output_pkl: scrabble_test.pkl
73 changes: 73 additions & 0 deletions config/experiments/scrabble/jay.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# @package _global_
# A configuration for the Scrabble environment with the Scrabble proxy.
# wandb: https://wandb.ai/alexhg/cube/runs/9u2d3zzh

defaults:
- override /env: scrabble
- override /gflownet: trajectorybalance
- override /proxy: scrabble
- override /logger: wandb
- override /user: alex

# Environment
env:
# Buffer
buffer:
data_path: null
train: null
test:
type: random
n: 1000
output_csv: scrabble_test.csv
output_pkl: scrabble_test.pkl
reward_func: identity

# Proxy
proxy:
vocabulary_check: True

# GFlowNet hyperparameters
gflownet:
random_action_prob: 0.1
optimizer:
batch_size:
forward: 100
lr: 0.0001
z_dim: 16
lr_z_mult: 100
n_train_steps: 10000

# Policy
policy:
forward:
type: mlp
n_hid: 512
n_layers: 5
checkpoint: forward
backward:
type: mlp
n_hid: 512
n_layers: 5
shared_weights: False
checkpoint: backward

# WandB
logger:
do:
online: true
lightweight: True
project_name: "scrabble"
tags:
- gflownet
- discrete
- scrabble
test:
period: 500
n: 1000
checkpoints:
period: 500

# Hydra
hydra:
run:
dir: ${user.logdir.root}/debug/ccube/${now:%Y-%m-%d_%H-%M-%S}
73 changes: 73 additions & 0 deletions config/experiments/scrabble/penguin.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# @package _global_
# A configuration for the Scrabble environment with the Scrabble proxy.
# wandb: https://wandb.ai/alexhg/cube/runs/9u2d3zzh

defaults:
- override /env: scrabble
- override /gflownet: trajectorybalance
- override /proxy: scrabble
- override /logger: wandb
- override /user: alex

# Environment
env:
# Buffer
buffer:
data_path: null
train: null
test:
type: random
n: 1000
output_csv: scrabble_test.csv
output_pkl: scrabble_test.pkl
reward_func: identity

# Proxy
proxy:
vocabulary_check: True

# GFlowNet hyperparameters
gflownet:
random_action_prob: 0.2
optimizer:
batch_size:
forward: 100
lr: 0.001
z_dim: 16
lr_z_mult: 100
n_train_steps: 10000

# Policy
policy:
forward:
type: mlp
n_hid: 1024
n_layers: 4
checkpoint: forward
backward:
type: mlp
n_hid: 1024
n_layers: 3
shared_weights: False
checkpoint: backward

# WandB
logger:
do:
online: true
lightweight: True
project_name: "scrabble"
tags:
- gflownet
- discrete
- scrabble
test:
period: 500
n: 1000
checkpoints:
period: 500

# Hydra
hydra:
run:
dir: ${user.logdir.root}/debug/ccube/${now:%Y-%m-%d_%H-%M-%S}
3 changes: 3 additions & 0 deletions config/proxy/scrabble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: gflownet.proxy.scrabble.ScrabbleScorer

vocabulary_check: False
3 changes: 1 addition & 2 deletions config/user/alex.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
logdir:
root: /home/alex/logs/gflownet
data:
root: /home/mila/h/hernanga/gflownet/data
alanine_dipeptide: /home/mila/h/hernanga/gflownet/data/alanine_dipeptide_conformers_1.npy
root: /home/alex/datasets
11 changes: 5 additions & 6 deletions gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,16 +697,17 @@ def state2proxy(
self, state: Union[List, TensorType["state_dim"]] = None
) -> TensorType["state_proxy_dim"]:
"""
Prepares a state in "GFlowNet format" for the proxy. By default, states2proxy
is called, which by default will return the state as is.
Prepares a single state in "GFlowNet format" for the proxy. By default, this
simply calls states2proxy, so the output is a "batch" containing a single state
in the proxy format.
Args
----
state : list
A state
"""
state = self._get_state(state)
return torch.squeeze(self.states2proxy([state]), dim=0)
return self.states2proxy([state])

def states2policy(
self, states: Union[List, TensorType["batch", "state_dim"]]
Expand Down Expand Up @@ -771,9 +772,7 @@ def reward(self, state=None, done=None, do_non_terminating=False):
done = self._get_done(done)
if not done and not do_non_terminating:
return tfloat(0.0, float_type=self.float, device=self.device)
return self.proxy2reward(
self.proxy(torch.unsqueeze(self.state2proxy(state), dim=0))[0]
)
return self.proxy2reward(self.proxy(self.state2proxy(state))[0])

# TODO: cleanup
def reward_batch(self, states: List[List], done=None):
Expand Down
Loading

0 comments on commit 3e89eb7

Please sign in to comment.