Refactor & standardize evaluation with Evaluator #287

Merged
merged 107 commits into from
Jun 6, 2024
107 commits
bbfbf76
v0: move configs and base API
vict0rsch Feb 14, 2024
d56af9e
add `eval_config` arg to `GFlowNetAgent` init
vict0rsch Feb 15, 2024
505b402
WIP: towards evaluator
vict0rsch Feb 15, 2024
a856791
rename to `eval_top_k`
vict0rsch Feb 15, 2024
3216ae9
Update for new `evaluator.eval()` api
vict0rsch Feb 15, 2024
294e9b5
quote in print
vict0rsch Feb 16, 2024
a24efa3
no figs as empty dicts instead of `(None,)`
vict0rsch Feb 16, 2024
70252e5
fix `self` to `gfn` in `eval_top_k`
vict0rsch Feb 16, 2024
9769016
`load_gflow_net_from_run_path` returns a tuple
vict0rsch Feb 16, 2024
bd212ad
move legacy code
vict0rsch Feb 16, 2024
73b1b15
`load_gflow_net_from_run_path` returns a tuple
vict0rsch Feb 16, 2024
cddc5d6
DOOOCSTRIIIINGS
vict0rsch Feb 16, 2024
54f2f23
`@classmethod`
vict0rsch Feb 18, 2024
f546506
unused `log_iter`
vict0rsch Feb 19, 2024
1225034
GFNA init docstring
vict0rsch Feb 19, 2024
f7662d3
refactor `should_` `train/eval/checkpoint`etc.
vict0rsch Feb 19, 2024
924dc18
don't log `None` values
vict0rsch Feb 19, 2024
274ff79
No need for a dedicated `log_test_metrics`
vict0rsch Feb 19, 2024
f75c876
move figs to `plot(...)`
vict0rsch Feb 19, 2024
1fdb401
setup `requires` system
vict0rsch Feb 19, 2024
baf701a
allow for custom `require`
vict0rsch Feb 19, 2024
00fc1b5
typo returned dict
vict0rsch Feb 19, 2024
c1a5dc8
move log prob metrcis to `compute_log_prob_metrics(...)`
vict0rsch Feb 19, 2024
2f08b0d
improve `make_metrics` and `make_requires`
vict0rsch Feb 19, 2024
4659707
refactor `requires`
vict0rsch Feb 19, 2024
987ab7b
typo -> `should_log_train`
vict0rsch Feb 19, 2024
06ccc61
`compute_density_metrics` for `eval()`
vict0rsch Feb 19, 2024
ee38802
add `eval:base` default
vict0rsch Feb 19, 2024
1468755
update configs
vict0rsch Feb 19, 2024
b670c14
move evaluator init later in gfna init
vict0rsch Feb 19, 2024
06d4d06
remove legacy `.test.` references
vict0rsch Feb 19, 2024
5809254
debug print
vict0rsch Feb 19, 2024
4a3b049
fix logdir exists logic and `exit(1)`
vict0rsch Feb 19, 2024
0cd21e0
trailing whitespace
vict0rsch Feb 19, 2024
4802ab2
remove `oracle` references
vict0rsch Feb 19, 2024
d8c9a7b
add `eval` default
vict0rsch Feb 20, 2024
272310c
`_self_` last to allow for overrides in `_self_` to other name spaces
vict0rsch Feb 20, 2024
5160759
`name` -> `display_name`
vict0rsch Feb 20, 2024
975024e
`ALL_REQS` and `ValueError`s
vict0rsch Feb 20, 2024
a532eab
missing tensor `.item()`
vict0rsch Feb 20, 2024
b3b7ff2
move `kde_pred` to continuous density metrics only
vict0rsch Feb 20, 2024
610dcfc
store pkl & csv paths as `Buffer` attributes
vict0rsch Feb 20, 2024
c4294f6
Imrpove robustness and allow `dict` metrics to `make_metrics`
vict0rsch Feb 20, 2024
0af5719
utils for tests file
vict0rsch Feb 20, 2024
d51b4b6
+ `gflownet_from_config`
vict0rsch Feb 20, 2024
25db913
generic fixtures
vict0rsch Feb 20, 2024
acc56f8
first tests for `gflownet.eval.base.GFlowNetEvaluator`
vict0rsch Feb 20, 2024
93acac1
clean up `oracle` files and `legacy.py`
vict0rsch Feb 20, 2024
e740e62
refactor `active_learning` to `use_context`
vict0rsch Feb 20, 2024
bb0486c
Remove `sample_only` gflownet arg (and config) and `make_train_test` …
vict0rsch Feb 20, 2024
0edbc6b
revert standardize `main` with `gflownet_from_config`
vict0rsch Feb 21, 2024
db4f1dc
trailing breakpoint
vict0rsch Feb 21, 2024
829cf4a
Update docstring
vict0rsch Feb 21, 2024
4347939
remove unused
vict0rsch Feb 21, 2024
25cf720
improve example
vict0rsch Feb 21, 2024
9f50b45
move `from_agent` and `from_dir` methods
vict0rsch Feb 21, 2024
3c5186b
use `gflownet_from_config` in `load_gflow_net_from_run_path`
vict0rsch Feb 21, 2024
df6d9af
`empty_ok=False` arg
vict0rsch Feb 21, 2024
626d34c
clean up example
vict0rsch Feb 21, 2024
a7018bb
document constants
vict0rsch Feb 21, 2024
b6832da
eval top k uses dict data structure
vict0rsch Mar 1, 2024
b6527f8
improve docstrings
vict0rsch Mar 1, 2024
9e09b63
add `update_all_metrics_and_requirements`
vict0rsch Mar 1, 2024
47c6992
have dedicated `plot_kwargs`
vict0rsch Mar 1, 2024
4985d9f
standardize `{"metrics": {}, "data": {}}` return pattern
vict0rsch Mar 1, 2024
054aa61
work on docstrings example
vict0rsch Mar 1, 2024
0ac39c8
towards abstract / base pattern
vict0rsch Mar 1, 2024
d73de32
update example docstring
vict0rsch Mar 1, 2024
cdb66b3
more docs
vict0rsch Mar 1, 2024
473edf9
allow `init` instantiation + more tutorial
vict0rsch Mar 4, 2024
670bdb8
`define_new_metrics`
vict0rsch Mar 4, 2024
82f0781
test `.`
vict0rsch Mar 4, 2024
34082c1
no `.` ?
vict0rsch Mar 4, 2024
43015d1
update links
vict0rsch Mar 4, 2024
d0cdb27
always use `evaluator`
vict0rsch Mar 4, 2024
21663da
reference logger
vict0rsch Mar 4, 2024
33fa8ac
more doc polih
vict0rsch Mar 4, 2024
1939e2a
Improve docs and refactor to `AbstractEvaluator` and `BaseEvaluator`
vict0rsch Mar 5, 2024
ef99765
comment-out trailing dev docs rendering filter
vict0rsch Mar 5, 2024
02e31ef
improve logging
vict0rsch Mar 5, 2024
64794ad
use `evaluator` namesmace
vict0rsch Mar 5, 2024
b3be860
adapt jay
vict0rsch Mar 5, 2024
1d883d3
move metrics to base
vict0rsch Mar 5, 2024
92182b1
fix tests
vict0rsch Mar 5, 2024
8139ab0
clean up prints
vict0rsch Mar 5, 2024
dce8abf
use evaluator
vict0rsch Mar 5, 2024
0322a41
evaluator in tests instantiate
vict0rsch Mar 5, 2024
8b27ac3
improve init docs
vict0rsch Mar 5, 2024
4bed143
outline
vict0rsch Mar 5, 2024
66e457d
typo
vict0rsch Mar 5, 2024
aed7110
add note
vict0rsch Mar 5, 2024
918fc7c
docs `plot` and `eval_top_k`
vict0rsch Mar 5, 2024
fd3fabd
test code-include
vict0rsch Mar 7, 2024
ece173e
Update gflownet/evaluator/__init__.py
carriepl May 30, 2024
93baba4
Apply suggestions from code review - Improve docstrings
carriepl May 30, 2024
43ef77e
Update gflownet/evaluator/__init__.py
carriepl May 30, 2024
7a1f97a
Complete GFlowNetAgent docstring
carriepl-mila May 31, 2024
0b0b0fc
Remove unused variable
carriepl-mila May 31, 2024
5e0f0fa
Fix merge conflicts
carriepl-mila Jun 4, 2024
800314c
Fix pytest filename conflicts
carriepl-mila Jun 4, 2024
a9bea9a
Re-integrate changes from main lost in merge
carriepl-mila Jun 4, 2024
81e4745
Update comments
carriepl-mila Jun 4, 2024
689ed6a
Improve comments in Logger
carriepl-mila Jun 5, 2024
d689270
Adjust sanity check runs (CTorus) to Evaluator config
alexhernandezgarcia Jun 5, 2024
4f4fcad
Fix typo
alexhernandezgarcia Jun 5, 2024
25a45e2
Evaluator: add samples_topk to plot(); add TODOs
alexhernandezgarcia Jun 5, 2024
6efaa19
Update evaluator config of Tetris sanity runs
alexhernandezgarcia Jun 5, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -75,7 +75,7 @@ The repository supports logging of train and evaluation metrics to [wandb.ai](ht

Bibtex Format

-```txt
+```text
@misc{hernandez-garcia2024,
author = {Hernandez-Garcia, Alex and Saxena, Nikita and Volokhova, Alexandra and Koziarski, Michał and Sharma, Divya and Viviano, Joseph D and Carrier, Pierre Luc and Schmidt, Victor},
title = {gflownet},
26 changes: 26 additions & 0 deletions config/evaluator/base.yaml
@@ -0,0 +1,26 @@
+_target_: gflownet.evaluator.base.BaseEvaluator
+
+# config formerly from logger.test
+first_it: True
+period: 100
+n: 100
+kde:
+  bandwidth: 0.1
+  kernel: gaussian
+n_top_k: 5000
+top_k: 100
+top_k_period: -1
+# Number of backward trajectories to estimate the log likelihood of each test data point
+n_trajs_logprobs: 10
+logprobs_batch_size: 100
+logprobs_bootstrap_size: 10000
+# Maximum number of test data points to compute log likelihood probs.
+max_data_logprobs: 1e5
+# Number of points to obtain a grid to estimate the reward density
+n_grid: 40000
+train_log_period: 1
+checkpoints_period: 1000
+# List of metrics as per gflownet/eval/evaluator.py:METRICS_NAMES
+# Set to null for all of them
+# Values must be comma separated like `metrics: "l1, kl, js"` (spaces are optional)
+metrics: all
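
The comments above describe `metrics` as a comma-separated string with optional spaces. As a rough illustration only, the sketch below shows one way such a value could be turned into a list of names. The function `parse_metrics` and the tuple `KNOWN_METRICS` are hypothetical and not taken from the repository; treating both `null` and `all` as "every metric" is likewise an assumption.

```python
from typing import List, Optional, Sequence

# Hypothetical metric names, for illustration only. The real list lives in the
# repository's evaluator module and is not shown in this diff.
KNOWN_METRICS = ("l1", "kl", "js")


def parse_metrics(
    spec: Optional[str], known: Sequence[str] = KNOWN_METRICS
) -> List[str]:
    """Turn a comma-separated spec such as "l1, kl, js" into a list of names."""
    # Assumption: both `null` (None in Python) and the string "all" select everything.
    if spec is None or spec.strip().lower() == "all":
        return list(known)
    # Spaces around names are optional, so strip each item and drop empty ones.
    names = [name.strip() for name in spec.split(",") if name.strip()]
    unknown = [name for name in names if name not in known]
    if unknown:
        raise ValueError(f"Unknown metrics: {unknown}")
    return names


print(parse_metrics("l1, kl"))
print(parse_metrics("all"))
```

Failing loudly on unknown names is a design choice of this sketch: a typo in a config file then surfaces at start-up rather than as a silently missing metric.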
14 changes: 8 additions & 6 deletions config/experiments/icml23/ctorus.yaml
@@ -2,6 +2,7 @@

defaults:
- override /env: ctorus
+- override /evaluator: base
- override /gflownet: trajectorybalance
- override /proxy: torus
- override /logger: wandb
@@ -32,6 +33,12 @@ gflownet:
lr_z_mult: 1000
n_train_steps: 5000

+# Evaluator
+evaluator:
+period: 25
+n: 1000
+checkpoints_period: 500

# Policy
policy:
forward:
@@ -50,15 +57,10 @@ policy:
logger:
lightweight: True
project_name: "Continuous GFlowNet"
-tags:
+tags:
- gflownet
- continuous
- ctorus
-test:
-period: 25
-n: 1000
-checkpoints:
-period: 500

# Hydra
hydra:
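
The evaluator block in this file sets `period: 25` and `checkpoints_period: 500` for a run with `n_train_steps: 5000`. The sketch below shows one plausible reading of such period settings. It assumes that an action fires on iterations that are multiples of its period, that a non-positive period (such as `top_k_period: -1` in the base config) disables it, and that `first_it` forces a run on the first iteration. None of these semantics is confirmed by the diff, and `should_run` is a hypothetical helper, not the repository's API.

```python
from typing import Optional


def should_run(it: int, period: Optional[int], first_it: bool = True) -> bool:
    """Return True if a periodic action should fire at 1-based iteration `it`."""
    # Assumption: a missing or non-positive period disables the action.
    if period is None or period <= 0:
        return False
    # Assumption: `first_it=True` forces a run on the very first iteration.
    if it == 1 and first_it:
        return True
    return it % period == 0


n_train_steps = 5000
eval_steps = [
    it for it in range(1, n_train_steps + 1) if should_run(it, period=25)
]
ckpt_steps = [
    it
    for it in range(1, n_train_steps + 1)
    if should_run(it, period=500, first_it=False)
]
print(len(eval_steps), len(ckpt_steps))
```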
13 changes: 7 additions & 6 deletions config/experiments/scrabble/jay.yaml
@@ -5,6 +5,7 @@
defaults:
- override /env: scrabble
- override /gflownet: trajectorybalance
+- override /evaluator: base
- override /proxy: scrabble
- override /logger: wandb
- override /user: alex
@@ -50,21 +51,21 @@ policy:
shared_weights: False
checkpoint: backward

+# Evaluator

Review comment by @carriepl (Collaborator), May 29, 2024:

The format for these evaluator arguments is different than in the icml23/ctorus.yaml config file. Is that a problem?

I am unsure what this comment refers to exactly, but I would just say that the icml23/ctorus.yaml file is really old (January 2023) so it would be fine to deprecate it / adapt it if needed. Yes, it contains the experiments of a paper, but I believe it's ok to adapt it to the new state of the repo.

+period: 500
+n: 1000
+checkpoints_period: 500

# WandB
logger:
do:
online: true
lightweight: True
project_name: "scrabble"
-tags:
+tags:
- gflownet
- discrete
- scrabble
-test:
-period: 500
-n: 1000
-checkpoints:
-period: 500

# Hydra
hydra:
6 changes: 1 addition & 5 deletions config/gflownet/gflownet.yaml
@@ -52,8 +52,4 @@ replay_sampling: permutation
# Train data set backward sampling
train_sampling: permutation
num_empirical_loss: 200000
-oracle:
-# Number of samples for oracle metrics
-n: 500
-sample_only: False
-active_learning: False
+use_context: False
33 changes: 0 additions & 33 deletions config/logger/base.yaml
@@ -6,39 +6,6 @@ do:

project_name: "GFlowNet"

-# Train metrics
-train:
-  period: 1
-# Test metrics
-test:
-  first_it: True
-  period: 100
-  n: 100
-  kde:
-    bandwidth: 0.1
-    kernel: gaussian
-  n_top_k: 5000
-  top_k: 100
-  top_k_period: -1
-  # Number of backward trajectories to estimate the log likelihood of each test data point
-  n_trajs_logprobs: 10
-  logprobs_batch_size: 100
-  logprobs_bootstrap_size: 10000
-  # Maximum number of test data points to compute log likelihood probs.
-  max_data_logprobs: 1e5
-  # Number of points to obtain a grid to estimate the reward density
-  n_grid: 40000
-# Oracle metrics
-oracle:
-  period: 100000
-  k:
-    - 1
-    - 10
-    - 100
-# Policy model checkpoints
-checkpoints:
-  period: 1000

# Log dir
logdir:
root: ./logs
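
For anyone carrying an older experiment config forward, the move from `logger.test` and `logger.checkpoints` to the `evaluator` namespace can be sketched as a small dictionary transformation. `migrate_logger_to_evaluator` is a hypothetical helper written for this note, not part of the repository. The mapping of `test.*` and `checkpoints.period` follows the config diffs in this PR. The mapping of `train.period` to `train_log_period` is an assumption based on the matching default value of 1, and the sketch also assumes `oracle` sat directly under `logger`.

```python
from copy import deepcopy
from typing import Any, Dict


def migrate_logger_to_evaluator(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `cfg` with evaluation settings moved out of `logger`."""
    new_cfg = deepcopy(cfg)
    logger = new_cfg.get("logger", {})
    evaluator = new_cfg.setdefault("evaluator", {})

    # `logger.test.*` keys keep their names under `evaluator`.
    evaluator.update(logger.pop("test", {}))
    # Nested periods are flattened into single keys.
    if "checkpoints" in logger:
        evaluator["checkpoints_period"] = logger.pop("checkpoints")["period"]
    if "train" in logger:
        # Assumption: `train.period` corresponds to `train_log_period`.
        evaluator["train_log_period"] = logger.pop("train")["period"]
    # The oracle settings were removed by this PR and have no new home.
    logger.pop("oracle", None)
    return new_cfg


old = {
    "logger": {
        "lightweight": True,
        "test": {"period": 25, "n": 1000},
        "checkpoints": {"period": 500},
    }
}
print(migrate_logger_to_evaluator(old))
```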
2 changes: 1 addition & 1 deletion config/logger/wandb.yaml
@@ -3,5 +3,5 @@ defaults:

_target_: gflownet.utils.logger.Logger

-tags:
+tags:
- gflownet
1 change: 1 addition & 0 deletions config/main.yaml
@@ -6,6 +6,7 @@ defaults:
- proxy: corners
- logger: wandb
- user: default
+- evaluator: base

# Device
device: cuda
3 changes: 2 additions & 1 deletion config/tests.yaml
@@ -1,11 +1,12 @@
defaults:
-- _self_
- env: grid
- gflownet: trajectorybalance
- proxy: uniform
- policy: mlp
- logger: base
- user: alex
+- evaluator: base
+- _self_

# Device
device: cpu
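
Moving `_self_` to the end of the defaults list changes which values win, because Hydra composes the entries of a defaults list in order and later entries override earlier ones. The sketch below imitates that ordering with plain dictionaries. It is a simplified model, not Hydra's implementation, and `compose`, `evaluator_group` and `primary` are hypothetical names used only here.

```python
from typing import Any, Dict, List


def compose(configs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge configs in order; later entries override earlier ones."""
    result: Dict[str, Any] = {}
    for cfg in configs:
        for key, value in cfg.items():
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                # Merge nested sections key by key instead of replacing them.
                result[key] = compose([result[key], value])
            else:
                result[key] = value
    return result


# Hypothetical stand-ins for a config group file and the primary config (`_self_`).
evaluator_group = {"evaluator": {"period": 100, "n": 100}}
primary = {"device": "cpu", "evaluator": {"period": 10}}

self_first = compose([primary, evaluator_group])  # `_self_` listed first
self_last = compose([evaluator_group, primary])  # `_self_` listed last
print(self_first["evaluator"]["period"], self_last["evaluator"]["period"])
```

With `_self_` last, values written in the primary config take precedence over the group defaults, which matches the intent stated in the commit message about allowing overrides in `_self_` to other namespaces.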
17 changes: 12 additions & 5 deletions docs/conf.py
@@ -47,17 +47,13 @@
"sphinx_design",
"sphinx_copybutton",
"sphinxext.opengraph",
+"code_include.extension",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------

@@ -122,6 +118,7 @@
# sphinx.ext.intersphinx
intersphinx_mapping = {
"torch": ("https://pytorch.org/docs/stable", None),
+"omegaconf": ("https://omegaconf.readthedocs.io/en/latest", None),
}

# sphinx.ext.autodoc & autoapi.extension
@@ -179,3 +176,13 @@
"enable": True,
"image": "./_static/images/gflownet-logo.png",
}


+# def skip_util_classes(app, what, name, obj, skip, options):
+#     return any(
+#         name.startswith(f"gflownet.{p}") for p in ["envs", "proxy", "policy", "utils"]
+#     )


+# def setup(sphinx):
+#     sphinx.connect("autoapi-skip-member", skip_util_classes)

Review comment (Collaborator):

Not sure what this part is meant to do. Is it something outdated that should be removed from the PR? Or is it a work in progress that should be finished and then uncommented?

16 changes: 11 additions & 5 deletions docs/contributors/example.rst
@@ -26,11 +26,17 @@ Remember, this works in docstrings *and* in stand-alone ``.rst`` files.

Cool features:

-Reference to a class: :class:`gflownet.proxy.crystals.dave.DAVE` (long), or another
-:class:`~gflownet.gflownet.GFlowNetAgent` or to a method:
-:meth:`~gflownet.gflownet.GFlowNetAgent.trajectorybalance_loss`
-or to an external function :func:`torch.cuda.synchronize()`
-(this <- needs to be listed in ``docs/conf.py:intersphinx_mapping``).
+Reference code docs of:
+
+- A class: :class:`gflownet.envs.grid.Grid` (long format)
+- Another class :class:`~gflownet.gflownet.GFlowNetAgent` (short format, by prepending ``~``)
+- A method :meth:`~gflownet.gflownet.GFlowNetAgent.trajectorybalance_loss`
+- Or even an external function :func:`torch.cuda.synchronize()`
+
+.. note
+
+    External content should be listed in ``docs/conf.py:intersphinx_mapping``.
+    More info in the `Read The Docs documentation <https://docs.readthedocs.io/en/stable/guides/intersphinx.html>`_.

An actual tutorial on ``.rst``:
`ReStructured Text for those who know Markdown <https://docs.open-mpi.org/en/v5.0.x/developers/rst-for-markdown-expats.html#hyperlinks-to-urls>`_
3 changes: 2 additions & 1 deletion docs/contributors/write-documentation.rst
@@ -13,7 +13,7 @@ Overview

There are two major types of documentation:

-1. **docstrings**: your code's docstrings will be automatically parsed by the documentation sofware (`Sphinx <https://www.sphinx-doc.org>`_, more in `about shpinx`_).
+1. **docstrings**: your code's docstrings will be automatically parsed by the documentation sofware (`Sphinx <https://www.sphinx-doc.org>`_, more in :ref:`about shpinx`).
2. **Manual** documentation such as this document. This can be for instance a detailed installation procedure, a tutorial, a FAQ, a contributor's guide etc. you name it!

**Both** are written in `ReStructured Text <https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html>`_ (``.rst``) format.
@@ -152,6 +152,7 @@ FAQ
- `Hover X Ref <https://sphinx-hoverxref.readthedocs.io/en/latest/index.html>`_ Enables tooltips to display contents on the hover of links
- `Napoleon <https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html>`_ enables the parsing of Google-style docstrings

+.. _about shpinx:

About Sphinx
------------
1 change: 1 addition & 0 deletions docs/requirements-docs.txt
@@ -9,3 +9,4 @@ sphinx-copybutton==0.5.1
sphinx-hoverxref==1.3.0
sphinxext-opengraph==0.8.2
sphinx-autoapi==3.0.0
+sphinx-code-include==1.1.1
3 changes: 2 additions & 1 deletion gflownet/envs/tetris.py
@@ -534,6 +534,7 @@ def plot_samples_topk(
k_top: int = 10,
n_rows: int = 2,
dpi: int = 150,
+**kwargs,
):
"""
Plot tetris boards of top K samples.
@@ -543,7 +544,7 @@ def plot_samples_topk(
samples : list
List of terminating states sampled from the policy.
rewards : list
-List of terminating states.
+Rewards of the samples.
k_top : int
The number of samples that will be included in the plot. The k_top samples
with the highest reward are selected.
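
The docstring says that the `k_top` samples with the highest reward are selected, and the signature now accepts `**kwargs`. A minimal sketch of both ideas follows, without any plotting. `select_top_k` is a hypothetical helper, not the method in `tetris.py`, and the reason given for accepting extra keyword arguments (letting a generic caller pass one keyword set to every environment) is an assumption about the intent of the change.

```python
from typing import Any, List, Sequence, Tuple


def select_top_k(
    samples: Sequence[Any],
    rewards: Sequence[float],
    k_top: int = 10,
    **kwargs: Any,
) -> Tuple[List[Any], List[float]]:
    """Return the `k_top` samples with the highest reward, best first."""
    # Extra keyword arguments are accepted and ignored, so a generic caller can
    # pass the same keyword set to every environment's plotting helper.
    order = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)
    order = order[:k_top]
    return [samples[i] for i in order], [rewards[i] for i in order]


top_samples, top_rewards = select_top_k(
    samples=["a", "b", "c", "d"],
    rewards=[0.1, 0.9, 0.5, 0.7],
    k_top=2,
    dpi=150,  # not used by this helper, but harmless thanks to **kwargs
)
print(top_samples, top_rewards)
```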
2 changes: 1 addition & 1 deletion gflownet/envs/tree.py
@@ -1567,6 +1567,6 @@ def test(
test_predictions[top_k_indices], self.y_test
)
for k, v in top_k_scores.items():
-result[f"test_top_k_{k}"] = v
+result[f"eval_top_k_{k}"] = v

return result