Skip to content

Commit

Permalink
allow init instantiation + more tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
vict0rsch committed Mar 4, 2024
1 parent cdb66b3 commit 473edf9
Show file tree
Hide file tree
Showing 7 changed files with 370 additions and 186 deletions.
2 changes: 2 additions & 0 deletions config/eval/base.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
_target_: gflownet.evaluator.base.GFlowNetEvaluator

# config formerly from logger.test
first_it: True
period: 100
Expand Down
106 changes: 102 additions & 4 deletions gflownet/evaluator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class methods should be used to instantiate this class.
class MyEvaluator(GFlowNetEvaluator):
def update_all_metrics_and_requirements(self):
'''
This method is called when the class is instantiated and is used to update
the global METRICS and ALL_REQS variables. It is used to define new metrics:
their display names (when logged) and requirements.
'''
global METRICS, ALL_REQS
METRICS["my_custom_metric"] = {
Expand All @@ -72,6 +77,28 @@ def update_all_metrics_and_requirements(self):
def my_custom_metric(self, some, arguments):
'''
Your metric-computing method. It should return a dict with two keys:
"metrics" and "data".
The "metrics" key should contain the new metric(s) and the "data" key
should contain the intermediate results that can be used to plot the
new metric(s).
Its arguments will come from the `eval()` method below.
Parameters
----------
some : type
description
arguments : type
description
Returns
-------
dict
A dict with two keys: "metrics" and "data".
'''
intermediate = some + arguments
return {
Expand All @@ -89,6 +116,29 @@ def my_custom_metric(self, some, arguments):
def my_custom_plot(
self, some_other=None, arguments=None, intermediate=None, **kwargs
):
'''
Your plotting method.
It should return a dict with figure titles as keys and the figures as
values.
Its arguments will come from the `plot()` method below, and basically come
from the "data" key of the output of other metrics-computing functions.
Parameters
----------
some_other : type, optional
description, by default None
arguments : type, optional
description, by default None
intermediate : type, optional
description, by default None
Returns
-------
dict
A dict with figure titles as keys and the figures as values.
'''
# whatever gets to **kwargs will be ignored, this is used to handle
# methods with varying signatures.
figs = {}
Expand All @@ -114,28 +164,76 @@ def my_custom_plot(
return figs
def plot(self, **kwargs):
'''
Your custom plot method.
It should return a dict with figure titles as keys and the figures as
values.
It will be called by the `eval_and_log` method to log the figures,
and given the "data" key of the output of other metrics-computing functions.
Returns
-------
dict
A dict with figure titles as keys and the figures as values.
'''
figs = super().plot(**kwargs)
figs.update(self.my_custom_plot(**kwargs))
return figs
def eval(self, metrics=None, **plot_kwargs):
gfn = self.gfn_agent
'''
Your custom eval method.
It should return a dict with two keys: "metrics" and "data".
It will be called by the `eval_and_log` method to log the metrics,
Parameters
----------
metrics : Union[list, dict], optional
The metrics you want to compute in this evaluation procedure,
by default None, meaning the ones defined in the config file.
Returns
-------
dict
A dict with two keys: "metrics" and "data".
'''
metrics = self.make_metrics(metrics)
reqs = self.make_requirements(metrics=metrics)
results = super().eval(metrics=metrics, **plot_kwargs)
if "new_req" in reqs:
some = self.gfn.sample_something()
arguments = utils.some_other_function()
my_results = self.my_custom_metric(some, arguments)
results["metrics"].update(my_results.get("metrics", {}))
results["data"].update(my_results.get("data", {}))
return results
In the previous example, the `update_all_metrics_and_requirements` method is used to
update the global `METRICS` and `ALL_REQS` variables. It will be called when the
`MyEvaluator` class is instantiated, in the init of `BaseEvaluator`.
Then define your own ``evaluator`` in the config file:
.. code-block:: yaml
# gflownet/config/evaluator/my_evaluator.yaml
defaults:
- base
_target_: gflownet.evaluator.my_evaluator.MyEvaluator
# any other params hereafter will extend or override the base class params:
period: 1000
In the previous example, the ``update_all_metrics_and_requirements`` method is used to
update the global ``METRICS`` and ``ALL_REQS`` variables. It will be called when the
``MyEvaluator`` class is instantiated, in the init of ``BaseEvaluator``.
By defining a new requirement, you ensure that the new metrics and plots will only be
computed if user asks for a metric that requires such computations.
Expand Down
Loading

0 comments on commit 473edf9

Please sign in to comment.