Merge pull request #287 from alexhernandezgarcia/evaluator
Refactor & standardize evaluation with `Evaluator`
alexhernandezgarcia authored Jun 6, 2024
2 parents 72f961b + 6efaa19 commit 2321aa4
Showing 39 changed files with 2,797 additions and 2,555 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ The repository supports logging of train and evaluation metrics to [wandb.ai](ht

Bibtex Format

```txt
```text
@misc{hernandez-garcia2024,
author = {Hernandez-Garcia, Alex and Saxena, Nikita and Volokhova, Alexandra and Koziarski, Michał and Sharma, Divya and Viviano, Joseph D and Carrier, Pierre Luc and Schmidt, Victor},
title = {gflownet},
Expand Down
26 changes: 26 additions & 0 deletions config/evaluator/base.yaml
@@ -0,0 +1,26 @@
_target_: gflownet.evaluator.base.BaseEvaluator

# config formerly from logger.test
first_it: True
period: 100
n: 100
kde:
  bandwidth: 0.1
  kernel: gaussian
n_top_k: 5000
top_k: 100
top_k_period: -1
# Number of backward trajectories to estimate the log likelihood of each test data point
n_trajs_logprobs: 10
logprobs_batch_size: 100
logprobs_bootstrap_size: 10000
# Maximum number of test data points to compute log likelihood probs.
max_data_logprobs: 1e5
# Number of points to obtain a grid to estimate the reward density
n_grid: 40000
train_log_period: 1
checkpoints_period: 1000
# List of metrics as per gflownet/eval/evaluator.py:METRICS_NAMES
# Set to null for all of them
# Values must be comma separated like `metrics: "l1, kl, js"` (spaces are optional)
metrics: all
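
Reviewer note: the comments above describe `metrics` as a comma-separated string with optional spaces, or `null` for every metric. A minimal sketch of how such a value could be resolved is given below. The names `parse_metrics` and `KNOWN_METRICS` are hypothetical, the three metric names are only the ones quoted in the comment, and treating the string `all` like `null` is an assumption based on the default value shown here, not something this diff confirms.

```python
from typing import Iterable, List, Optional, Union

# Hypothetical subset of metric names; the real list lives in METRICS_NAMES.
KNOWN_METRICS = ("l1", "kl", "js")


def parse_metrics(
    value: Optional[Union[str, Iterable[str]]],
    known: Iterable[str] = KNOWN_METRICS,
) -> List[str]:
    """Resolve a `metrics` config value into a list of metric names."""
    known = list(known)
    # Assumption: `null` (None) and the string "all" both select every metric.
    if value is None or value == "all":
        return known
    if isinstance(value, str):
        # Comma-separated string; spaces around the names are optional.
        names = [name.strip() for name in value.split(",") if name.strip()]
    else:
        names = list(value)
    unknown = [name for name in names if name not in known]
    if unknown:
        raise ValueError(f"Unknown metrics: {unknown}")
    return names
```

With this sketch, `parse_metrics("l1, kl, js")` and `parse_metrics("l1,kl,js")` give the same list, and a misspelled name fails early instead of being silently skipped.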
14 changes: 8 additions & 6 deletions config/experiments/icml23/ctorus.yaml
Expand Up @@ -2,6 +2,7 @@

defaults:
- override /env: ctorus
- override /evaluator: base
- override /gflownet: trajectorybalance
- override /proxy: torus
- override /logger: wandb
Expand Down Expand Up @@ -32,6 +33,12 @@ gflownet:
lr_z_mult: 1000
n_train_steps: 5000

# Evaluator
evaluator:
  period: 25
  n: 1000
  checkpoints_period: 500

# Policy
policy:
forward:
Expand All @@ -50,15 +57,10 @@ policy:
logger:
lightweight: True
project_name: "Continuous GFlowNet"
tags:
tags:
- gflownet
- continuous
- ctorus
test:
period: 25
n: 1000
checkpoints:
period: 500

# Hydra
hydra:
Expand Down
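
Reviewer note: the `period`, `first_it`, `top_k_period` and `checkpoints_period` options all describe step-based schedules. The sketch below shows one plausible reading of them. The function name `should_run` is hypothetical, and both the rule that a non-positive period (such as `top_k_period: -1`) disables the task and the rule that `first_it` forces a run at step 1 are assumptions, not behaviour confirmed by this diff.

```python
from typing import Optional


def should_run(step: int, period: Optional[int], first_it: bool = False) -> bool:
    """Return True if a periodic task is due at training step `step` (1-based)."""
    # Assumption: a missing or non-positive period (e.g. -1) disables the task.
    if period is None or period <= 0:
        return False
    # Assumption: `first_it` forces a run on the very first step.
    if first_it and step == 1:
        return True
    return step % period == 0


# Steps at which evaluation would run with `period: 25` and `first_it: True`.
eval_steps = [step for step in range(1, 101) if should_run(step, 25, first_it=True)]
```

Under these assumptions, the first 100 training steps of the ctorus experiment would trigger evaluation at steps 1, 25, 50, 75 and 100.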
13 changes: 7 additions & 6 deletions config/experiments/scrabble/jay.yaml
Expand Up @@ -5,6 +5,7 @@
defaults:
- override /env: scrabble
- override /gflownet: trajectorybalance
- override /evaluator: base
- override /proxy: scrabble
- override /logger: wandb
- override /user: alex
Expand Down Expand Up @@ -50,21 +51,21 @@ policy:
shared_weights: False
checkpoint: backward

# Evaluator
period: 500
n: 1000
checkpoints_period: 500

# WandB
logger:
do:
online: true
lightweight: True
project_name: "scrabble"
tags:
tags:
- gflownet
- discrete
- scrabble
test:
period: 500
n: 1000
checkpoints:
period: 500

# Hydra
hydra:
Expand Down
6 changes: 1 addition & 5 deletions config/gflownet/gflownet.yaml
Expand Up @@ -52,8 +52,4 @@ replay_sampling: permutation
# Train data set backward sampling
train_sampling: permutation
num_empirical_loss: 200000
oracle:
# Number of samples for oracle metrics
n: 500
sample_only: False
active_learning: False
use_context: False
33 changes: 0 additions & 33 deletions config/logger/base.yaml
Expand Up @@ -6,39 +6,6 @@ do:

project_name: "GFlowNet"

# Train metrics
train:
  period: 1
# Test metrics
test:
  first_it: True
  period: 100
  n: 100
  kde:
    bandwidth: 0.1
    kernel: gaussian
  n_top_k: 5000
  top_k: 100
  top_k_period: -1
  # Number of backward trajectories to estimate the log likelihood of each test data point
  n_trajs_logprobs: 10
  logprobs_batch_size: 100
  logprobs_bootstrap_size: 10000
  # Maximum number of test data points to compute log likelihood probs.
  max_data_logprobs: 1e5
  # Number of points to obtain a grid to estimate the reward density
  n_grid: 40000
# Oracle metrics
oracle:
  period: 100000
  k:
    - 1
    - 10
    - 100
# Policy model checkpoints
checkpoints:
  period: 1000

# Log dir
logdir:
root: ./logs
Expand Down
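
Reviewer note: taken together with `config/evaluator/base.yaml`, this hunk moves `logger.test.*` to the top level of `evaluator`, `logger.train.period` to `evaluator.train_log_period`, and `logger.checkpoints.period` to `evaluator.checkpoints_period`. The `oracle` block has no counterpart in the new file as shown in this diff. A sketch of that mapping for updating older experiment configs follows. The helper `migrate_logger_config` and the table `LOGGER_TO_EVALUATOR` are hypothetical and not part of the repository.

```python
from typing import Any, Dict

# Old dotted key under `logger` -> new key under `evaluator`, as read off this diff.
LOGGER_TO_EVALUATOR = {
    "train.period": "train_log_period",
    "checkpoints.period": "checkpoints_period",
}


def migrate_logger_config(logger_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Build an `evaluator` config dict from a pre-refactor `logger` config dict."""
    evaluator_cfg: Dict[str, Any] = {}
    # Everything under `logger.test` moves to the top level of `evaluator`.
    evaluator_cfg.update(logger_cfg.get("test", {}))
    for old_key, new_key in LOGGER_TO_EVALUATOR.items():
        section, field = old_key.split(".")
        if field in logger_cfg.get(section, {}):
            evaluator_cfg[new_key] = logger_cfg[section][field]
    # The `oracle` section is dropped: this diff shows no new home for it.
    return evaluator_cfg


# The old ctorus override, expressed as a plain dict.
old_logger = {"test": {"period": 25, "n": 1000}, "checkpoints": {"period": 500}}
new_evaluator = migrate_logger_config(old_logger)
```

For the ctorus values, the result is `period: 25`, `n: 1000` and `checkpoints_period: 500`, which matches the `evaluator` block added to `config/experiments/icml23/ctorus.yaml`.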
2 changes: 1 addition & 1 deletion config/logger/wandb.yaml
Expand Up @@ -3,5 +3,5 @@ defaults:

_target_: gflownet.utils.logger.Logger

tags:
tags:
- gflownet
1 change: 1 addition & 0 deletions config/main.yaml
Expand Up @@ -6,6 +6,7 @@ defaults:
- proxy: corners
- logger: wandb
- user: default
- evaluator: base

# Device
device: cuda
Expand Down
3 changes: 2 additions & 1 deletion config/tests.yaml
@@ -1,11 +1,12 @@
defaults:
- _self_
- env: grid
- gflownet: trajectorybalance
- proxy: uniform
- policy: mlp
- logger: base
- user: alex
- evaluator: base
- _self_

# Device
device: cpu
Expand Down
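
Reviewer note: this hunk moves `_self_` from the top of the defaults list to the bottom. In Hydra, entries of the defaults list are merged in order and later entries win, so with `_self_` last the values written in `config/tests.yaml` itself override those coming from the config groups. The sketch below is a simplified model of that ordering using plain dictionaries. It is not Hydra's implementation, and the config values are hypothetical.

```python
from typing import Any, Dict, List


def deep_merge(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` in which values from `update` take precedence."""
    merged = dict(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def compose(layers: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge config layers in order; later layers override earlier ones."""
    result: Dict[str, Any] = {}
    for layer in layers:
        result = deep_merge(result, layer)
    return result


# Hypothetical values, for illustration only.
evaluator_base = {"evaluator": {"period": 100, "n": 100}}
own_values = {"evaluator": {"period": 10}}  # stands in for the file's own keys

self_first = compose([own_values, evaluator_base])  # `_self_` listed first
self_last = compose([evaluator_base, own_values])   # `_self_` listed last
```

In this model, `self_first` ends up with `period: 100` because the group default is merged last, while `self_last` keeps the file's own `period: 10` and inherits `n: 100`.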
17 changes: 12 additions & 5 deletions docs/conf.py
Expand Up @@ -47,17 +47,13 @@
"sphinx_design",
"sphinx_copybutton",
"sphinxext.opengraph",
"code_include.extension",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------

Expand Down Expand Up @@ -122,6 +118,7 @@
# sphinx.ext.intersphinx
intersphinx_mapping = {
"torch": ("https://pytorch.org/docs/stable", None),
"omegaconf": ("https://omegaconf.readthedocs.io/en/latest", None),
}

# sphinx.ext.autodoc & autoapi.extension
Expand Down Expand Up @@ -179,3 +176,13 @@
"enable": True,
"image": "./_static/images/gflownet-logo.png",
}


# def skip_util_classes(app, what, name, obj, skip, options):
# return any(
# name.startswith(f"gflownet.{p}") for p in ["envs", "proxy", "policy", "utils"]
# )


# def setup(sphinx):
# sphinx.connect("autoapi-skip-member", skip_util_classes)
16 changes: 11 additions & 5 deletions docs/contributors/example.rst
Expand Up @@ -26,11 +26,17 @@ Remember, this works in docstrings *and* in stand-alone ``.rst`` files.

Cool features:

Reference to a class: :class:`gflownet.proxy.crystals.dave.DAVE` (long), or another
:class:`~gflownet.gflownet.GFlowNetAgent` or to a method:
:meth:`~gflownet.gflownet.GFlowNetAgent.trajectorybalance_loss`
or to an external function :func:`torch.cuda.synchronize()`
(this <- needs to be listed in ``docs/conf.py:intersphinx_mapping``).
Reference code docs of:

- A class: :class:`gflownet.envs.grid.Grid` (long format)
- Another class :class:`~gflownet.gflownet.GFlowNetAgent` (short format, by prepending ``~``)
- A method :meth:`~gflownet.gflownet.GFlowNetAgent.trajectorybalance_loss`
- Or even an external function :func:`torch.cuda.synchronize()`

.. note::

    External content should be listed in ``docs/conf.py:intersphinx_mapping``.
    More info in the `Read The Docs documentation <https://docs.readthedocs.io/en/stable/guides/intersphinx.html>`_.

An actual tutorial on ``.rst``:
`ReStructured Text for those who know Markdown <https://docs.open-mpi.org/en/v5.0.x/developers/rst-for-markdown-expats.html#hyperlinks-to-urls>`_
Expand Down
3 changes: 2 additions & 1 deletion docs/contributors/write-documentation.rst
Expand Up @@ -13,7 +13,7 @@ Overview

There are two major types of documentation:

1. **docstrings**: your code's docstrings will be automatically parsed by the documentation software (`Sphinx <https://www.sphinx-doc.org>`_, more in `about shpinx`_).
1. **docstrings**: your code's docstrings will be automatically parsed by the documentation software (`Sphinx <https://www.sphinx-doc.org>`_, more in :ref:`about shpinx`).
2. **Manual** documentation such as this document. This can be for instance a detailed installation procedure, a tutorial, a FAQ, a contributor's guide etc. you name it!

**Both** are written in `ReStructured Text <https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html>`_ (``.rst``) format.
Expand Down Expand Up @@ -152,6 +152,7 @@ FAQ
- `Hover X Ref <https://sphinx-hoverxref.readthedocs.io/en/latest/index.html>`_ Enables tooltips to display contents on the hover of links
- `Napoleon <https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html>`_ enables the parsing of Google-style docstrings

.. _about shpinx:

About Sphinx
------------
Expand Down
1 change: 1 addition & 0 deletions docs/requirements-docs.txt
Expand Up @@ -9,3 +9,4 @@ sphinx-copybutton==0.5.1
sphinx-hoverxref==1.3.0
sphinxext-opengraph==0.8.2
sphinx-autoapi==3.0.0
sphinx-code-include==1.1.1
3 changes: 2 additions & 1 deletion gflownet/envs/tetris.py
Expand Up @@ -534,6 +534,7 @@ def plot_samples_topk(
k_top: int = 10,
n_rows: int = 2,
dpi: int = 150,
**kwargs,
):
"""
Plot tetris boards of top K samples.
Expand All @@ -543,7 +544,7 @@ def plot_samples_topk(
samples : list
List of terminating states sampled from the policy.
rewards : list
List of terminating states.
Rewards of the samples.
k_top : int
The number of samples that will be included in the plot. The k_top samples
with the highest reward are selected.
Expand Down
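
Reviewer note: the commit does not state why `**kwargs` was added to `plot_samples_topk`. One plausible reason is that a generic caller can then pass the same set of keyword options to every environment's plotting method, and each method ignores the ones it does not use. The sketch below illustrates that effect with two hypothetical stand-in functions. Neither is the Tetris implementation.

```python
def plot_strict(samples, rewards, k_top=10):
    """Hypothetical plotting function without **kwargs."""
    return {"n_plotted": min(k_top, len(samples))}


def plot_tolerant(samples, rewards, k_top=10, **kwargs):
    """Same function, but unknown keyword arguments are accepted and ignored."""
    return {"n_plotted": min(k_top, len(samples))}


# A generic caller that passes the same options to every environment.
common_options = {"k_top": 2, "n_grid": 40000}  # `n_grid` is irrelevant here
samples, rewards = ["a", "b", "c"], [0.1, 0.5, 0.2]

result = plot_tolerant(samples, rewards, **common_options)

try:
    plot_strict(samples, rewards, **common_options)
    strict_failed = False
except TypeError:
    # The strict signature rejects the unexpected `n_grid` keyword.
    strict_failed = True
```

The trade-off is that a misspelled keyword is also swallowed silently by the tolerant signature, so callers get no error for typos.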
2 changes: 1 addition & 1 deletion gflownet/envs/tree.py
Expand Up @@ -1567,6 +1567,6 @@ def test(
test_predictions[top_k_indices], self.y_test
)
for k, v in top_k_scores.items():
result[f"test_top_k_{k}"] = v
result[f"eval_top_k_{k}"] = v

return result
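
Reviewer note: this hunk renames the result keys from `test_top_k_{k}` to `eval_top_k_{k}`, so dashboards or scripts that read the old keys will no longer find them. A small sketch of a helper that maps old keys to the new prefix when comparing runs from before and after this commit is shown below. The function `rename_metric_prefix` is hypothetical and not part of the repository.

```python
from typing import Any, Dict


def rename_metric_prefix(
    metrics: Dict[str, Any], old: str = "test_", new: str = "eval_"
) -> Dict[str, Any]:
    """Return a copy of `metrics` with the `old` key prefix replaced by `new`."""
    return {
        (new + key[len(old):] if key.startswith(old) else key): value
        for key, value in metrics.items()
    }


# Hypothetical metrics logged before this commit.
old_metrics = {"test_top_k_1": 0.9, "test_top_k_10": 0.8, "loss": 0.1}
new_metrics = rename_metric_prefix(old_metrics)
```

Keys that do not start with the old prefix, such as `loss` here, are left unchanged.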
0 comments on commit 2321aa4