Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scrabble env #281

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
300c659
Remove words complexity from init and docstring and complete get_acti…
alexhernandezgarcia Oct 27, 2023
c2bfc37
Fixes and cleanup; mask forward
alexhernandezgarcia Oct 27, 2023
654a381
First version of tests.
alexhernandezgarcia Oct 27, 2023
ca5c10c
All main methods complete and tested; conversions are missing.
alexhernandezgarcia Oct 27, 2023
6339757
Base sequence finished
alexhernandezgarcia Oct 30, 2023
aecea69
Make tokens a tuple and do not sort
alexhernandezgarcia Oct 30, 2023
de32d25
DNA environment via sequence
alexhernandezgarcia Oct 30, 2023
bc65211
Merge branch 'sequence-envs' into scrabble
alexhernandezgarcia Nov 2, 2023
0e16d4e
merge
alexhernandezgarcia Nov 2, 2023
38bf3d2
Uniform proxy accepts list as input
alexhernandezgarcia Nov 2, 2023
bbcc831
get_uniform_terminating_states for sequences; fixes
alexhernandezgarcia Nov 2, 2023
41da9aa
Common tests for sequence
alexhernandezgarcia Nov 2, 2023
7935315
Fix common test: copy state before adding to list
alexhernandezgarcia Nov 2, 2023
d78df71
Sequence config
alexhernandezgarcia Nov 2, 2023
83132e0
WIP: Scrabble environment, first steps
alexhernandezgarcia Nov 2, 2023
c560afa
Enable state2proxy to handle states2proxy that return a list.
alexhernandezgarcia Nov 2, 2023
9c01cc6
Remove padding of readables of sequences
alexhernandezgarcia Nov 2, 2023
8010ba6
Scrabble: version 0 done
alexhernandezgarcia Nov 2, 2023
a86e4a2
Scrabble proxy
alexhernandezgarcia Nov 2, 2023
44fb11b
Add max length to sequence
alexhernandezgarcia Nov 2, 2023
2b84cc7
Fix no-pad sequences in scrabble proxy
alexhernandezgarcia Nov 2, 2023
bf5dbaf
Add scrabble env config
alexhernandezgarcia Nov 2, 2023
fb79eec
Make vocabulary of 7 letters or fewer and add option to not check voc…
alexhernandezgarcia Nov 3, 2023
c29b9d6
Scrabble demo
alexhernandezgarcia Nov 3, 2023
8f6c0bc
Scrabble proxy format is a tensor of indices
alexhernandezgarcia Nov 3, 2023
9e7e8ad
Adapt scrabble proxy to work tensors for better efficiency
alexhernandezgarcia Nov 3, 2023
5d8b38f
Re-enable vocabulary check flag
alexhernandezgarcia Nov 3, 2023
820863d
Scrabble demo env
alexhernandezgarcia Nov 10, 2023
7a3947e
Update config alex
alexhernandezgarcia Nov 10, 2023
ac6cd74
Scrabble env config
alexhernandezgarcia Nov 10, 2023
7b0bb2e
Remove all seqs stuff and leave only scrabble
alexhernandezgarcia Feb 1, 2024
287e0de
Clean up dna and sequence files and fix import
alexhernandezgarcia Feb 1, 2024
f786773
Various fixes
alexhernandezgarcia Feb 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions config/env/scrabble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
defaults:
- base

_target_: gflownet.envs.scrabble.Scrabble

id: scrabble
# Buffer
buffer:
data_path: null
train: null
test:
type: uniform
n: 10
output_csv: scrabble_test.csv
output_pkl: scrabble_test.pkl
3 changes: 3 additions & 0 deletions config/proxy/scrabble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: gflownet.proxy.scrabble.ScrabbleScorer

vocabulary_check: False
5 changes: 2 additions & 3 deletions config/user/alex.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
logdir:
root: /network/scratch/h/hernanga/logs/gflownet
root: /home/alex/logs/gflownet
data:
root: /home/mila/h/hernanga/gflownet/data
alanine_dipeptide: /home/mila/h/hernanga/gflownet/data/alanine_dipeptide_conformers_1.npy
root: /home/alex/datasets
11 changes: 10 additions & 1 deletion gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,16 @@ def state2proxy(self, state: Union[List, TensorType["state_dim"]] = None):
A state
"""
state = self._get_state(state)
return torch.squeeze(self.states2proxy([state]), dim=0)
state_proxy = self.states2proxy([state])
if isinstance(state_proxy, list):
return state_proxy[0]
elif torch.is_tensor(state_proxy):
return torch.squeeze(state_proxy, dim=0)
else:
raise NotImplementedError(
"The output of states2proxy must be either a list or a tensor. "
f"Got {type(state_proxy)}."
)

def states2policy(
self, states: Union[List, TensorType["batch", "state_dim"]]
Expand Down
Loading
Loading