From bbfbf76e24e8b5efb46d4097d515f67c7d380be5 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 14 Feb 2024 14:29:36 -0500 Subject: [PATCH 001/106] v0: move configs and base API --- config/eval/base.yaml | 18 ++++++++++++ config/logger/base.yaml | 17 ----------- gflownet/evaluator/__init__.py | 0 gflownet/evaluator/base.py | 54 ++++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 17 deletions(-) create mode 100644 config/eval/base.yaml create mode 100644 gflownet/evaluator/__init__.py create mode 100644 gflownet/evaluator/base.py diff --git a/config/eval/base.yaml b/config/eval/base.yaml new file mode 100644 index 000000000..4c749a94a --- /dev/null +++ b/config/eval/base.yaml @@ -0,0 +1,18 @@ +_target_: gflownet.evaluator.base.GFlowNetEvaluator + +# config formerly from logger.test +first_it: True +period: 100 +n: 100 +kde: + bandwidth: 0.1 + kernel: gaussian +n_top_k: 5000 +top_k: 100 +top_k_period: -1 +# Number of backward trajectories to estimate the log likelihood of each test data point +n_trajs_logprobs: 10 +logprobs_batch_size: 100 +logprobs_bootstrap_size: 10000 +# Maximum number of test data points to compute log likelihood probs. +max_data_logprobs: 1e5 \ No newline at end of file diff --git a/config/logger/base.yaml b/config/logger/base.yaml index a50b04479..a0abf661a 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -10,23 +10,6 @@ project_name: "GFlowNet" train: period: 1 # Test metrics -test: - first_it: True - period: 100 - n: 100 - kde: - bandwidth: 0.1 - kernel: gaussian - n_top_k: 5000 - top_k: 100 - top_k_period: -1 - # Number of backward trajectories to estimate the log likelihood of each test data point - n_trajs_logprobs: 10 - logprobs_batch_size: 100 - logprobs_bootstrap_size: 10000 - # Maximum number of test data points to compute log likelihood probs. - max_data_logprobs: 1e5 -# Oracle metrics oracle: period: 100000 k: diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py new file mode 100644 index 000000000..10a642233 --- /dev/null +++ b/gflownet/evaluator/base.py @@ -0,0 +1,54 @@ +import os +from typing import Union + +from gflownet.utils.common import load_gflow_net_from_run_path + +_sentinel = object() + + +class GFlowNetEvaluator: + def __init__(self, **kwargs): + if kwargs.get("sentinel") is not _sentinel: + raise NotImplementedError( + "Base evaluator class should not be instantiated. Use " + + "GFlowNetEvaluator.from_dir or GFlowNetEvaluator.from_agent methods." + ) + self.gfn_agent = kwargs.get("gfn_agent") + + @staticmethod + def from_dir( + path: Union[str, os.PathLike], + no_wandb: bool = True, + print_config: bool = False, + device: str = "cuda", + load_final_ckpt: bool = True, + ): + gfn_agent = load_gflow_net_from_run_path( + path, + no_wandb=no_wandb, + print_config=print_config, + device=device, + load_final_ckpt=load_final_ckpt, + ) + return GFlowNetEvaluator.from_agent(gfn_agent) + + @staticmethod + def from_agent(gfn_agent): + return GFlowNetEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) + + def plot(self): + print("Base evaluator plot method does not do anything.") + + def compute_metrics(self, metrics: list = []): + print("Base evaluator compute_metrics method does not do anything.") + + def evaluate(self, n_episodes: int = 1): + print("Base evaluator evaluate method does not do anything.") + + +if __name__ == "__main__": + # dev test case, will move to tests + gfn_run_dir = "/network/scratch/s/schmidtv/crystals/logs/icml24/crystalgfn/4074836/2024-01-27_20-54-55/5908fe41" + gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) + gfne.plot() + gfne.compute_metrics() From d56af9e31f6201f5626ed92c692df7331f2b2fc2 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Thu, 15 Feb 2024 18:40:54 -0500 Subject: [PATCH 002/106] add `eval_config` arg to `GFlowNetAgent` init --- gflownet/gflownet.py | 10 +++++----- main.py | 2 ++ tests/gflownet/envs/common.py | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 85b8fa28e..595e85b8e 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -5,32 +5,28 @@ """ import copy -import os import pickle import time -from collections import defaultdict from pathlib import Path from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn -from scipy.special import logsumexp from torch.distributions import Bernoulli from tqdm import tqdm from gflownet.envs.base import GFlowNetEnv +from gflownet.evaluator.base import GFlowNetEvaluator from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer from gflownet.utils.common import ( - batch_with_rest, bootstrap_samples, set_device, set_float_precision, tbool, tfloat, tlong, - torch2np, ) @@ -57,6 +53,7 @@ def __init__( sample_only=False, replay_sampling="permutation", train_sampling="permutation", + eval_config=None, **kwargs, ): # Seed @@ -67,6 +64,9 @@ def __init__( self.float = set_float_precision(float_precision) # Environment self.env = env + # Evaluator + self.eval_config = eval_config + self.evaluator = GFlowNetEvaluator.from_agent(self) # Continuous environments self.continuous = hasattr(self.env, "continuous") and self.env.continuous if self.continuous and optimizer.loss in ["flowmatch", "flowmatching"]: diff --git a/main.py b/main.py index 25171afc6..94425e0a2 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ """ Runnable script with hydra capabilities """ + import os import pickle import random @@ -82,6 +83,7 @@ def main(config): state_flow=state_flow, buffer=config.env.buffer, logger=logger, + eval_config=config.eval, ) # Train GFlowNet diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index c499995dc..294743c3c 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -482,6 +482,7 @@ def test__gflownet_minimal_runs(self, n_repeat=1): backward_policy=backward_policy, buffer=config.env.buffer, logger=logger, + eval_config=config.eval, ) gflownet.train() assert True From 505b402e6eeb17deddc16cf7c6e0ff047f24383f Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Thu, 15 Feb 2024 18:42:35 -0500 Subject: [PATCH 003/106] WIP: towards evaluator --- config/eval/base.yaml | 9 +- config/logger/base.yaml | 2 +- gflownet/evaluator/base.py | 589 ++++++++++++++++++++++++++++++++++++- gflownet/gflownet.py | 315 +------------------- gflownet/utils/logger.py | 48 +-- 5 files changed, 607 insertions(+), 356 deletions(-) diff --git a/config/eval/base.yaml b/config/eval/base.yaml index 4c749a94a..9908967e5 100644 --- a/config/eval/base.yaml +++ b/config/eval/base.yaml @@ -1,5 +1,3 @@ -_target_: gflownet.evaluator.base.GFlowNetEvaluator - # config formerly from logger.test first_it: True period: 100 @@ -15,4 +13,9 @@ n_trajs_logprobs: 10 logprobs_batch_size: 100 logprobs_bootstrap_size: 10000 # Maximum number of test data points to compute log likelihood probs. -max_data_logprobs: 1e5 \ No newline at end of file +max_data_logprobs: 1e5 + +# List of metrics as per gflownet/eval/evaluator.py:METRICS_NAMES +# Set to null for all of them +# Values must be comma separated like `metrics: "l1, kl, js"` (spaces are optional) +metrics: null \ No newline at end of file diff --git a/config/logger/base.yaml b/config/logger/base.yaml index a0abf661a..8ab28fdfa 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -9,7 +9,7 @@ project_name: "GFlowNet" # Train metrics train: period: 1 -# Test metrics + oracle: period: 100000 k: diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 10a642233..a0bbf3af0 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -1,19 +1,220 @@ +import copy import os +import pickle +import time +from collections import defaultdict from typing import Union -from gflownet.utils.common import load_gflow_net_from_run_path +import numpy as np +import torch +from scipy.special import logsumexp + +from gflownet.utils.batch import Batch +from gflownet.utils.common import ( + batch_with_rest, + load_gflow_net_from_run_path, + tfloat, + torch2np, +) _sentinel = object() +METRICS_NAMES = { + "l1": "L1 error", + "kl": "KL Div.", + "jsd": "Jensen Shannon Div.", + "corr_prob_traj_rewards": "Corr. (test probs., rewards)", + "var_logrewards_logp": "Var(logR - logp) test", + "nll_tt": "NLL of test data", + "mean_logprobs_std": "Mean BS Std(logp)", + "mean_probs_std": "Mean BS Std(p)", + "logprobs_std_nll_ratio": "BS Std(logp) / NLL", +} + class GFlowNetEvaluator: def __init__(self, **kwargs): + """ + Base evaluator class for GFlowNetAgent. + + In charge of evaluating the GFlowNetAgent, computing metrics plotting figures + and optionally logging results using the GFlowNetAgent's logger. + + Only the `from_dir` and `from_agent` class methods should be used to instantiate + this class. + + Raises + ------ + NotImplementedError + If the `sentinel` keyword argument is not `_sentinel`, which is used to + prevent instantiation of the base class without using the `from_dir` or + `from_agent` class methods. + + """ if kwargs.get("sentinel") is not _sentinel: raise NotImplementedError( "Base evaluator class should not be instantiated. Use " + "GFlowNetEvaluator.from_dir or GFlowNetEvaluator.from_agent methods." ) self.gfn_agent = kwargs.get("gfn_agent") + self.config = self.gfn_agent.eval_config + self.logger = self.gfn_agent.logger + + self.set_metrics(self.config.metrics) + + def set_metrics(self, metrics=None): + """ + Set the metrics to be computed by the evaluator to the `self.metrics` attribute. + + If `None`, all metrics are computed. If a string, it can be a comma-separated + list of metric names, with or without spaces. All metrics must be in + `METRICS_NAMES`. + + Parameters + ---------- + metrics : (Union[str, List[str]], optional) + Metrics to compute when running the `evaluator.eval()` function. Defaults to + None, i.e. all metrics in `METRICS_NAMES` are computed. + + Raises + ------ + ValueError + If a metric name is not in `METRICS_NAMES`. + """ + if metrics is None: + metrics = METRICS_NAMES.keys() + if isinstance(metrics, str): + if "," in metrics: + metrics = [m.strip() for m in metrics.split(",")] + else: + metrics = [metrics] + for m in metrics: + if m not in METRICS_NAMES: + raise ValueError(f"Unknown metric name: {m}") + self.metrics = metrics + + def do_train(self, step): + """ + Check if training logs should be done at the current step. The decision is based + on the `self.config.train.period` attribute. + + Set `self.config.train.period` to `None` or a negative value to disable + training. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if training should be done at the current step, False otherwise. + """ + if self.config.train.period is None or self.config.train.period < 0: + return False + else: + return not step % self.config.train.period + + def do_test(self, step): + """ + Check if testing should be done at the current step. The decision is based on + the `self.config.test.period` attribute. + + Set `self.config.test.first_it` to `True` if testing should be done at the first + iteration step. Otherwise, testing will be done aftter `self.config.test.period` + steps. + + Set `self.config.test.period` to `None` or a negative value to disable testing. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if testing should be done at the current step, False otherwise. + """ + if self.config.test.period is None or self.config.test.period < 0: + return False + elif step == 1 and self.config.test.first_it: + return True + else: + return not step % self.config.test.period + + def do_top_k(self, step): + """ + Check if top k plots and metrics should be done at the current step. The + decision is based on the `self.config.test.top_k` and + `self.config.test.top_k_period` attributes. + + Set `self.config.test.top_k` to `None` or a negative value to disable top k + plots and metrics. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if top k plots and metrics should be done at the current step, False + """ + if self.config.test.top_k is None or self.config.test.top_k < 0: + return False + + if self.config.test.top_k_period is None or self.config.test.top_k_period < 0: + return False + + return step == 2 or step % self.config.test.top_k_period == 0 + + def do_oracle(self, step): + """ + Check if oracle should be done at the current step. The decision is based on the + `self.config.oracle.period` attribute. + + Set `self.config.oracle.period` to `None` or a negative value to disable oracle. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if oracle should be done at the current step, False otherwise. + """ + if self.config.oracle.period is None or self.config.oracle.period < 0: + return False + else: + return not step % self.oracle.period + + def do_checkpoints(self, step): + """ + Check if checkpoints should be done at the current step. The decision is based + on the `self.checkpoints.period` attribute. + + Set `self.checkpoints.period` to `None` or a negative value to disable + checkpoints. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if checkpoints should be done at the current step, False otherwise. + """ + if self.checkpoints.period is None or self.checkpoints.period < 0: + return False + else: + return not step % self.checkpoints.period @staticmethod def from_dir( @@ -23,6 +224,27 @@ def from_dir( device: str = "cuda", load_final_ckpt: bool = True, ): + """ + Instantiate a GFlowNetEvaluator from a run directory. + + Parameters + ---------- + path : Union[str, os.PathLike] + Path to the run directory from which to load the GFlowNetAgent. + no_wandb : bool, optional + Prevent wandb initialization, by default True + print_config : bool, optional + Whether or not to print the resulting (loaded) config, by default False + device : str, optional + Device to use for the instantiated GFlowNetAgent, by default "cuda" + load_final_ckpt : bool, optional + Use the latest possible checkpoint available in the path, by default True + + Returns + ------- + GFlowNetEvaluator + Instance of GFlowNetEvaluator with the GFlowNetAgent loaded from the run. + """ gfn_agent = load_gflow_net_from_run_path( path, no_wandb=no_wandb, @@ -34,21 +256,376 @@ def from_dir( @staticmethod def from_agent(gfn_agent): + """ + Instantiate a GFlowNetEvaluator from a GFlowNetAgent. + + Parameters + ---------- + gfn_agent : GFlowNetAgent + Instance of GFlowNetAgent to use for the GFlowNetEvaluator. + + Returns + ------- + GFlowNetEvaluator + Instance of GFlowNetEvaluator with the provided GFlowNetAgent. + """ + from gflownet.gflownet import GFlowNetAgent + + assert isinstance(gfn_agent, GFlowNetAgent), ( + "gfn_agent should be an instance of GFlowNetAgent, but is an instance of " + + f"{type(gfn_agent)}." + ) + return GFlowNetEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) def plot(self): + """ + Plots this evaluator should do. This is a base method that does nothing and + should be overridden by subclasses. + """ print("Base evaluator plot method does not do anything.") - def compute_metrics(self, metrics: list = []): - print("Base evaluator compute_metrics method does not do anything.") + def eval(self, metrics=_sentinel, **plot_kwargs): + """ + Evaluate the GFlowNetAgent and compute metrics and plots. + + If `metrics` is not provided, the evaluator's `metrics` attribute is + used (default). + + Parameters + ---------- + metrics : _type_, optional + List of metrics to compute, by default the evaluator's `metrics` attribute. + plot_kwargs : dict, optional + Additional keyword arguments to pass to the plotting methods. + + Returns + ------- + list + List of computed metrics and figures: [l1, kl, jsd, corr_prob_traj_rewards, + var_logrewards_logp, nll_tt, mean_logprobs_std, mean_probs_std, + logprobs_std_nll_ratio, figs, env_metrics] (should be refactored to dict) + """ + gfn = self.gfn_agent + + if metrics is None: + # TODO-V use this in the rest of the code to selectively compute metrics + metrics = set(METRICS_NAMES.keys()) + + if gfn.buffer.test_pkl is None: + result = { + "metrics": { + k: getattr(gfn, k) if hasattr(gfn, k) else None for k in metrics + } + } + result["figs"] = (None,) + result["env_metrics"] = {} + return result.values() + + with open(gfn.buffer.test_pkl, "rb") as f: + dict_tt = pickle.load(f) + x_tt = dict_tt["x"] + + # Compute correlation between the rewards of the test data and the log + # likelihood of the data according the the GFlowNet policy; and NLL. + # TODO: organise code for better efficiency and readability + logprobs_x_tt, logprobs_std, probs_std = gfn.estimate_logprobs_data( + x_tt, + n_trajectories=self.logger.test.n_trajs_logprobs, + max_data_size=self.logger.test.max_data_logprobs, + batch_size=self.logger.test.logprobs_batch_size, + bs_num_samples=self.logger.test.logprobs_bootstrap_size, + ) + mean_logprobs_std = logprobs_std.mean().item() + mean_probs_std = probs_std.mean().item() + rewards_x_tt = gfn.env.reward_batch(x_tt) + corr_prob_traj_rewards = np.corrcoef( + np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt + )[0, 1] + var_logrewards_logp = torch.var( + torch.log(tfloat(rewards_x_tt, float_type=gfn.float, device=gfn.device)) + - logprobs_x_tt + ).item() + nll_tt = -logprobs_x_tt.mean().item() + logprobs_std_nll_ratio = torch.mean(-logprobs_std / logprobs_x_tt).item() + + x_sampled = [] + if gfn.buffer.test_type is not None and gfn.buffer.test_type == "all": + batch, _ = gfn.sample_batch(n_forward=self.logger.test.n, train=False) + assert batch.is_valid() + x_sampled = batch.get_terminating_states() + + if "density_true" in dict_tt: + density_true = dict_tt["density_true"] + else: + rewards = gfn.env.reward_batch(x_tt) + z_true = rewards.sum() + density_true = rewards / z_true + with open(gfn.buffer.test_pkl, "wb") as f: + dict_tt["density_true"] = density_true + pickle.dump(dict_tt, f) + hist = defaultdict(int) + for x in x_sampled: + hist[tuple(x)] += 1 + z_pred = sum([hist[tuple(x)] for x in x_tt]) + 1e-9 + density_pred = np.array([hist[tuple(x)] / z_pred for x in x_tt]) + log_density_true = np.log(density_true + 1e-8) + log_density_pred = np.log(density_pred + 1e-8) + elif gfn.continuous and hasattr(gfn.env, "fit_kde"): + batch, _ = gfn.sample_batch(n_forward=self.logger.test.n, train=False) + assert batch.is_valid() + x_sampled = batch.get_terminating_states() + # TODO make it work with conditional env + x_sampled = torch2np(gfn.env.states2proxy(x_sampled)) + x_tt = torch2np(gfn.env.states2proxy(x_tt)) + kde_pred = gfn.env.fit_kde( + x_sampled, + kernel=self.logger.test.kde.kernel, + bandwidth=self.logger.test.kde.bandwidth, + ) + if "log_density_true" in dict_tt and "kde_true" in dict_tt: + log_density_true = dict_tt["log_density_true"] + kde_true = dict_tt["kde_true"] + else: + # Sample from reward via rejection sampling + x_from_reward = gfn.env.sample_from_reward(n_samples=self.logger.test.n) + x_from_reward = torch2np(gfn.env.states2proxy(x_from_reward)) + # Fit KDE with samples from reward + kde_true = gfn.env.fit_kde( + x_from_reward, + kernel=self.logger.test.kde.kernel, + bandwidth=self.logger.test.kde.bandwidth, + ) + # Estimate true log density using test samples + # TODO: this may be specific-ish for the torus or not + scores_true = kde_true.score_samples(x_tt) + log_density_true = scores_true - logsumexp(scores_true, axis=0) + # Add log_density_true and kde_true to pickled test dict + with open(gfn.buffer.test_pkl, "wb") as f: + dict_tt["log_density_true"] = log_density_true + dict_tt["kde_true"] = kde_true + pickle.dump(dict_tt, f) + # Estimate pred log density using test samples + # TODO: this may be specific-ish for the torus or not + scores_pred = kde_pred.score_samples(x_tt) + log_density_pred = scores_pred - logsumexp(scores_pred, axis=0) + density_true = np.exp(log_density_true) + density_pred = np.exp(log_density_pred) + else: + # TODO: refactor + env_metrics = gfn.env.test(x_sampled) + return { + "metrics": { + "l1": gfn.l1, + "kl": gfn.kl, + "jsd": gfn.jsd, + "corr_prob_traj_rewards": corr_prob_traj_rewards, + "var_logrewards_logp": var_logrewards_logp, + "nll_tt": nll_tt, + "mean_logprobs_std": mean_logprobs_std, + "mean_probs_std": mean_probs_std, + "logprobs_std_nll_ratio": logprobs_std_nll_ratio, + }, + "figs": (None,), + "env_metrics": env_metrics, + } + # L1 error + l1 = np.abs(density_pred - density_true).mean() + # KL divergence + kl = (density_true * (log_density_true - log_density_pred)).mean() + # Jensen-Shannon divergence + log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log(0.5) + jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens)) + jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) + + # Plots + # TODO-V: move to evaluator.plot()? + if hasattr(gfn.env, "plot_reward_samples"): + fig_reward_samples = gfn.env.plot_reward_samples(x_sampled, **plot_kwargs) + else: + fig_reward_samples = None + if hasattr(gfn.env, "plot_kde"): + fig_kde_pred = gfn.env.plot_kde(kde_pred, **plot_kwargs) + fig_kde_true = gfn.env.plot_kde(kde_true, **plot_kwargs) + else: + fig_kde_pred = None + fig_kde_true = None + + return { + "metrics": { + "l1": l1, + "kl": kl, + "jsd": jsd, + "corr_prob_traj_rewards": corr_prob_traj_rewards, + "var_logrewards_logp": var_logrewards_logp, + "nll_tt": nll_tt, + "mean_logprobs_std": mean_logprobs_std, + "mean_probs_std": mean_probs_std, + "logprobs_std_nll_ratio": logprobs_std_nll_ratio, + }, + "figs": { + "True reward and GFlowNet samples": fig_reward_samples, + "GFlowNet KDE Policy": fig_kde_pred, + "Reward KDE": fig_kde_true, + }, + "env_metrics": {}, + } + + @torch.no_grad() + def eval_top_k(self, it, gfn_states=None, random_states=None): + """ + Sample from the current GFN and compute metrics and plots for the top k states + according to both the energy and the reward. + + Parameters + ---------- + it : int + current iteration + gfn_states : list, optional + Already sampled gfn states. Defaults to None. + random_states : list, optional + Already sampled random states. Defaults to None. + + Returns + ------- + tuple[dict, dict[str, plt.Figure], dict] + Computed dict of metrics, and figures (as {str: plt.Figure}), and optionally + (only once) summary metrics. + """ + # only do random top k plots & metrics once + do_random = it // self.logger.test.top_k_period == 1 + duration = None + summary = {} + prob = copy.deepcopy(self.random_action_prob) + print() + if not gfn_states: + # sample states from the current gfn + batch = Batch(env=self.env, device=self.device, float_type=self.float) + self.random_action_prob = 0 + t = time.time() + print("Sampling from GFN...", end="\r") + for b in batch_with_rest( + 0, self.logger.test.n_top_k, self.batch_size_total + ): + sub_batch, _ = self.sample_batch(n_forward=len(b), train=False) + batch.merge(sub_batch) + duration = time.time() - t + gfn_states = batch.get_terminating_states() - def evaluate(self, n_episodes: int = 1): - print("Base evaluator evaluate method does not do anything.") + # compute metrics and get plots + print("[eval_top_k] Making GFN plots...", end="\r") + metrics, figs, fig_names = self.env.top_k_metrics_and_plots( + gfn_states, self.logger.test.top_k, name="gflownet", step=it + ) + if duration: + metrics["gflownet top k sampling duration"] = duration + + if do_random: + # sample random states from uniform actions + if not random_states: + batch = Batch(env=self.env, device=self.device, float_type=self.float) + self.random_action_prob = 1.0 + print("[eval_top_k] Sampling at random...", end="\r") + for b in batch_with_rest( + 0, self.logger.test.n_top_k, self.batch_size_total + ): + sub_batch, _ = self.sample_batch(n_forward=len(b), train=False) + batch.merge(sub_batch) + # compute metrics and get plots + random_states = batch.get_terminating_states() + print("[eval_top_k] Making Random plots...", end="\r") + ( + random_metrics, + random_figs, + random_fig_names, + ) = self.env.top_k_metrics_and_plots( + random_states, self.logger.test.top_k, name="random", step=None + ) + # add to current metrics and plots + summary.update(random_metrics) + figs += random_figs + fig_names += random_fig_names + # compute training data metrics and get plots + print("[eval_top_k] Making train plots...", end="\r") + ( + train_metrics, + train_figs, + train_fig_names, + ) = self.env.top_k_metrics_and_plots( + None, self.logger.test.top_k, name="train", step=None + ) + # add to current metrics and plots + summary.update(train_metrics) + figs += train_figs + fig_names += train_fig_names + + self.random_action_prob = prob + + print(" " * 100, end="\r") + print("eval_top_k metrics:") + max_k = max([len(k) for k in (list(metrics.keys()) + list(summary.keys()))]) + 1 + print( + " • " + + "\n • ".join( + f"{k:{max_k}}: {v:.4f}" + for k, v in (list(metrics.items()) + list(summary.items())) + ) + ) + print() + + figs = {f: n for f, n in zip(figs, fig_names)} + + return metrics, figs, summary + + def eval_and_log(self, it, metrics=_sentinel): + """ + Evaluate the GFlowNetAgent and log the results with its logger. + + Will call `self.eval()` and log the results using the GFlowNetAgent's logger + `log_metrics()` and `log_plots()` methods. + + Parameters + ---------- + it : int + Current iteration step. + metrics : Union[str, List[str]], optional + List of metrics to compute, by default the evaluator's `metrics` attribute. + """ + gfn = self.gfn_agent + # TODO-V: do we need to set attributes? + result = self.eval(metrics=metrics) + for m, v in result["metrics"].items(): + setattr(gfn, m, v) + + self.logger.log_test_metrics(*result["metrics"].values(), it, gfn.use_context) + self.logger.log_metrics(result["env_metrics"], it, use_context=gfn.use_context) + self.logger.log_plots(result["figs"], it, use_context=gfn.use_context) + + def eval_and_log_top_k(self, it): + """ + Evaluate the GFlowNetAgent's top k samples performance and log the results with + its logger. + + Parameters + ---------- + it : int + Current iteration step, by default None. + """ + + metrics, figs, summary = self.eval_top_k(it) + self.logger.log_plots(figs, it, use_context=self.use_context) + self.logger.log_metrics(metrics, use_context=self.use_context, step=it) + self.logger.log_summary(summary) if __name__ == "__main__": # dev test case, will move to tests - gfn_run_dir = "/network/scratch/s/schmidtv/crystals/logs/icml24/crystalgfn/4074836/2024-01-27_20-54-55/5908fe41" + from pathlib import Path + + scratch = Path(os.environ["SCRATCH"]) + run_dirs = scratch / "crystals/logs/icml24/crystalgfn" + gfn_run_dir = run_dirs / "4074836/2024-01-27_20-54-55/5908fe41" + gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) gfne.plot() gfne.compute_metrics() diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 595e85b8e..a4eb0ab9c 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -978,50 +978,12 @@ def train(self): # Train loop pbar = tqdm(range(1, self.n_train_steps + 1), disable=not self.logger.progress) for it in pbar: - # Test - fig_names = [ - "True reward and GFlowNet samples", - "GFlowNet KDE Policy", - "Reward KDE", - ] - if self.logger.do_test(it): - ( - self.l1, - self.kl, - self.jsd, - self.corr_prob_traj_rewards, - self.var_logrewards_logp, - self.nll_tt, - self.mean_logprobs_std, - self.mean_probs_std, - self.logprobs_std_nll_ratio, - figs, - env_metrics, - ) = self.test() - self.logger.log_test_metrics( - self.l1, - self.kl, - self.jsd, - self.corr_prob_traj_rewards, - self.var_logrewards_logp, - self.nll_tt, - self.mean_logprobs_std, - self.mean_probs_std, - self.logprobs_std_nll_ratio, - it, - self.use_context, - ) - self.logger.log_metrics(env_metrics, it, use_context=self.use_context) - self.logger.log_plots( - figs, it, fig_names=fig_names, use_context=self.use_context - ) - if self.logger.do_top_k(it): - metrics, figs, fig_names, summary = self.test_top_k(it) - self.logger.log_plots( - figs, it, use_context=self.use_context, fig_names=fig_names - ) - self.logger.log_metrics(metrics, use_context=self.use_context, step=it) - self.logger.log_summary(summary) + # Test and log + if self.evaluator.do_test(it): + self.evaluator.eval_and_log(it) + if self.evaluator.do_top_k(it): + self.evaluator.eval_and_log_top_k(it) + t0_iter = time.time() batch = Batch(env=self.env, device=self.device, float_type=self.float) for j in range(self.sttr): @@ -1145,271 +1107,6 @@ def train(self): if self.use_context is False: self.logger.end() - def test(self, **plot_kwargs): - """ - Computes metrics by sampling trajectories from the forward policy. - """ - if self.buffer.test_pkl is None: - return ( - self.l1, - self.kl, - self.jsd, - self.corr_prob_traj_rewards, - self.var_logrewards_logp, - self.nll_tt, - self.mean_logprobs_std, - self.mean_probs_std, - self.logprobs_std_nll_ratio, - (None,), - {}, - ) - with open(self.buffer.test_pkl, "rb") as f: - dict_tt = pickle.load(f) - x_tt = dict_tt["x"] - - # Compute correlation between the rewards of the test data and the log - # likelihood of the data according the the GFlowNet policy; and NLL. - # TODO: organise code for better efficiency and readability - logprobs_x_tt, logprobs_std, probs_std = self.estimate_logprobs_data( - x_tt, - n_trajectories=self.logger.test.n_trajs_logprobs, - max_data_size=self.logger.test.max_data_logprobs, - batch_size=self.logger.test.logprobs_batch_size, - bs_num_samples=self.logger.test.logprobs_bootstrap_size, - ) - mean_logprobs_std = logprobs_std.mean().item() - mean_probs_std = probs_std.mean().item() - rewards_x_tt = self.env.reward_batch(x_tt) - corr_prob_traj_rewards = np.corrcoef( - np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt - )[0, 1] - var_logrewards_logp = torch.var( - torch.log(tfloat(rewards_x_tt, float_type=self.float, device=self.device)) - - logprobs_x_tt - ).item() - nll_tt = -logprobs_x_tt.mean().item() - logprobs_std_nll_ratio = torch.mean(-logprobs_std / logprobs_x_tt).item() - - x_sampled = [] - if self.buffer.test_type is not None and self.buffer.test_type == "all": - batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) - assert batch.is_valid() - x_sampled = batch.get_terminating_states() - - if "density_true" in dict_tt: - density_true = dict_tt["density_true"] - else: - rewards = self.env.reward_batch(x_tt) - z_true = rewards.sum() - density_true = rewards / z_true - with open(self.buffer.test_pkl, "wb") as f: - dict_tt["density_true"] = density_true - pickle.dump(dict_tt, f) - hist = defaultdict(int) - for x in x_sampled: - hist[tuple(x)] += 1 - z_pred = sum([hist[tuple(x)] for x in x_tt]) + 1e-9 - density_pred = np.array([hist[tuple(x)] / z_pred for x in x_tt]) - log_density_true = np.log(density_true + 1e-8) - log_density_pred = np.log(density_pred + 1e-8) - elif self.continuous and hasattr(self.env, "fit_kde"): - batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) - assert batch.is_valid() - x_sampled = batch.get_terminating_states() - # TODO make it work with conditional env - x_sampled = torch2np(self.env.states2proxy(x_sampled)) - x_tt = torch2np(self.env.states2proxy(x_tt)) - kde_pred = self.env.fit_kde( - x_sampled, - kernel=self.logger.test.kde.kernel, - bandwidth=self.logger.test.kde.bandwidth, - ) - if "log_density_true" in dict_tt and "kde_true" in dict_tt: - log_density_true = dict_tt["log_density_true"] - kde_true = dict_tt["kde_true"] - else: - # Sample from reward via rejection sampling - x_from_reward = self.env.sample_from_reward( - n_samples=self.logger.test.n - ) - x_from_reward = torch2np(self.env.states2proxy(x_from_reward)) - # Fit KDE with samples from reward - kde_true = self.env.fit_kde( - x_from_reward, - kernel=self.logger.test.kde.kernel, - bandwidth=self.logger.test.kde.bandwidth, - ) - # Estimate true log density using test samples - # TODO: this may be specific-ish for the torus or not - scores_true = kde_true.score_samples(x_tt) - log_density_true = scores_true - logsumexp(scores_true, axis=0) - # Add log_density_true and kde_true to pickled test dict - with open(self.buffer.test_pkl, "wb") as f: - dict_tt["log_density_true"] = log_density_true - dict_tt["kde_true"] = kde_true - pickle.dump(dict_tt, f) - # Estimate pred log density using test samples - # TODO: this may be specific-ish for the torus or not - scores_pred = kde_pred.score_samples(x_tt) - log_density_pred = scores_pred - logsumexp(scores_pred, axis=0) - density_true = np.exp(log_density_true) - density_pred = np.exp(log_density_pred) - else: - # TODO: refactor - env_metrics = self.env.test(x_sampled) - return ( - self.l1, - self.kl, - self.jsd, - corr_prob_traj_rewards, - var_logrewards_logp, - nll_tt, - mean_logprobs_std, - mean_probs_std, - logprobs_std_nll_ratio, - (None,), - env_metrics, - ) - # L1 error - l1 = np.abs(density_pred - density_true).mean() - # KL divergence - kl = (density_true * (log_density_true - log_density_pred)).mean() - # Jensen-Shannon divergence - log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log(0.5) - jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens)) - jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) - - # Plots - - if hasattr(self.env, "plot_reward_samples"): - fig_reward_samples = self.env.plot_reward_samples(x_sampled, **plot_kwargs) - else: - fig_reward_samples = None - if hasattr(self.env, "plot_kde"): - fig_kde_pred = self.env.plot_kde(kde_pred, **plot_kwargs) - fig_kde_true = self.env.plot_kde(kde_true, **plot_kwargs) - else: - fig_kde_pred = None - fig_kde_true = None - return ( - l1, - kl, - jsd, - corr_prob_traj_rewards, - var_logrewards_logp, - nll_tt, - mean_logprobs_std, - mean_probs_std, - logprobs_std_nll_ratio, - [fig_reward_samples, fig_kde_pred, fig_kde_true], - {}, - ) - - @torch.no_grad() - def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): - """ - Sample from the current GFN and compute metrics and plots for the top k states - according to both the energy and the reward. - - Parameters - ---------- - it : int - Current iteration. - progress : bool, optional - Print sampling progress. Defaults to False. - gfn_states : list, optional - Already sampled gfn states. Defaults to None. - random_states : list, optional - Already sampled random states. Defaults to None. - - Returns - ------- - tuple[dict, list[plt.Figure], list[str], dict] - Computed dict of metrics, and figures, their names and optionally (only - once) summary metrics. - """ - # only do random top k plots & metrics once - do_random = it // self.logger.test.top_k_period == 1 - duration = None - summary = {} - prob = copy.deepcopy(self.random_action_prob) - print() - if not gfn_states: - # sample states from the current gfn - batch = Batch(env=self.env, device=self.device, float_type=self.float) - self.random_action_prob = 0 - t = time.time() - print("Sampling from GFN...", end="\r") - for b in batch_with_rest( - 0, self.logger.test.n_top_k, self.batch_size_total - ): - sub_batch, _ = self.sample_batch(n_forward=len(b), train=False) - batch.merge(sub_batch) - duration = time.time() - t - gfn_states = batch.get_terminating_states() - - # compute metrics and get plots - print("[test_top_k] Making GFN plots...", end="\r") - metrics, figs, fig_names = self.env.top_k_metrics_and_plots( - gfn_states, self.logger.test.top_k, name="gflownet", step=it - ) - if duration: - metrics["gflownet top k sampling duration"] = duration - - if do_random: - # sample random states from uniform actions - if not random_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float) - self.random_action_prob = 1.0 - print("[test_top_k] Sampling at random...", end="\r") - for b in batch_with_rest( - 0, self.logger.test.n_top_k, self.batch_size_total - ): - sub_batch, _ = self.sample_batch(n_forward=len(b), train=False) - batch.merge(sub_batch) - # compute metrics and get plots - random_states = batch.get_terminating_states() - print("[test_top_k] Making Random plots...", end="\r") - ( - random_metrics, - random_figs, - random_fig_names, - ) = self.env.top_k_metrics_and_plots( - random_states, self.logger.test.top_k, name="random", step=None - ) - # add to current metrics and plots - summary.update(random_metrics) - figs += random_figs - fig_names += random_fig_names - # compute training data metrics and get plots - print("[test_top_k] Making train plots...", end="\r") - ( - train_metrics, - train_figs, - train_fig_names, - ) = self.env.top_k_metrics_and_plots( - None, self.logger.test.top_k, name="train", step=None - ) - # add to current metrics and plots - summary.update(train_metrics) - figs += train_figs - fig_names += train_fig_names - - self.random_action_prob = prob - - print(" " * 100, end="\r") - print("test_top_k metrics:") - max_k = max([len(k) for k in (list(metrics.keys()) + list(summary.keys()))]) + 1 - print( - " • " - + "\n • ".join( - f"{k:{max_k}}: {v:.4f}" - for k, v in (list(metrics.items()) + list(summary.items())) - ) - ) - print() - return metrics, figs, fig_names, summary - def get_log_corr(self, times): data_logq = [] times.update( diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index fdc075f4e..3ca8cfd46 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -1,6 +1,7 @@ import os from datetime import datetime from pathlib import Path +from typing import Union import matplotlib.pyplot as plt import numpy as np @@ -83,41 +84,6 @@ def __init__( # Write wandb URL self.write_url_file() - def do_train(self, step): - if self.train.period is None or self.train.period < 0: - return False - else: - return not step % self.train.period - - def do_test(self, step): - if self.test.period is None or self.test.period < 0: - return False - elif step == 1 and self.test.first_it: - return True - else: - return not step % self.test.period - - def do_top_k(self, step): - if self.test.top_k is None or self.test.top_k < 0: - return False - - if self.test.top_k_period is None or self.test.top_k_period < 0: - return False - - return step == 2 or step % self.test.top_k_period == 0 - - def do_oracle(self, step): - if self.oracle.period is None or self.oracle.period < 0: - return False - else: - return not step % self.oracle.period - - def do_checkpoints(self, step): - if self.checkpoints.period is None or self.checkpoints.period < 0: - return False - else: - return not step % self.checkpoints.period - def write_url_file(self): if self.wandb is not None: self.url = self.wandb.run.get_url() @@ -181,17 +147,24 @@ def log_histogram(self, key, value, step, use_context=True): fig = self.wandb.Image(fig) self.wandb.log({key: fig}, step) - def log_plots(self, figs: list, step, fig_names=None, use_context=True): + def log_plots(self, figs: Union[dict, list], step, use_context=True): if not self.do.online: self.close_figs(figs) return - keys = fig_names or [f"Figure {i} at step {step}" for i in range(len(figs))] + if isinstance(figs, dict): + keys = figs.keys() + figs = list(figs.values()) + else: + assert isinstance(figs, list), "figs must be a list or a dict" + keys = [f"Figure {i} at step {step}" for i in range(len(figs))] + for key, fig in zip(keys, figs): if use_context: # fixme context = self.context + "/" + key if fig is not None: figimg = self.wandb.Image(fig) self.wandb.log({key: figimg}, step) + self.close_figs(figs) def close_figs(self, figs: list): @@ -301,6 +274,7 @@ def log_sampler_test( ) def log_sampler_oracle(self, energies: array, step: int, use_context: bool): + # TODO-V -> remove? Unused if not self.do.online: return if step.do_oracle(step): From a856791022e9159fc58d29e50a67bbe952161dbb Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Thu, 15 Feb 2024 18:42:50 -0500 Subject: [PATCH 004/106] rename to `eval_top_k` --- gflownet/envs/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/tree.py b/gflownet/envs/tree.py index 13684f4a3..3895f336c 100644 --- a/gflownet/envs/tree.py +++ b/gflownet/envs/tree.py @@ -1567,6 +1567,6 @@ def test( test_predictions[top_k_indices], self.y_test ) for k, v in top_k_scores.items(): - result[f"test_top_k_{k}"] = v + result[f"eval_top_k_{k}"] = v return result From 3216ae91763f222b20a3df67585663a125d11784 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Thu, 15 Feb 2024 18:43:12 -0500 Subject: [PATCH 005/106] Update for new `evaluator.eval()` api --- scripts/eval_gflownet.py | 43 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index f5d4b6c31..1fda71ebd 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -128,6 +128,18 @@ def set_device(device: str): return torch.device("cpu") +def path_compatible(str): + """ + Replace all non-alphanumeric characters with underscores + + Parameters + ---------- + str : str + The string to be made compatible + """ + return "".join([c if c.isalnum() else "_" for c in str]) + + def main(args): if args.randominit: prefix = "randominit" @@ -153,37 +165,26 @@ def main(args): if not args.samples_only: gflownet.logger.test.n = args.n_samples - ( - l1, - kl, - jsd, - corr_prob_traj_rew, - var_logrew_logp, - nll, - figs, - env_metrics, - ) = gflownet.test() - # Save figures - keys = ["True reward and GFlowNet samples", "GFlowNet KDE Policy", "Reward KDE"] - fignames = ["samples", "kde_gfn", "kde_reward"] + eval_results = gflownet.evaluator.eval() + + # TODO-V: legacy -> ok to remove? + # keys = ["True reward and GFlowNet samples", "GFlowNet KDE Policy", "Reward KDE"] + # fignames = ["samples", "kde_gfn", "kde_reward"] output_dir = base_dir / "figures" print("output_dir: ", str(output_dir)) output_dir.mkdir(parents=True, exist_ok=True) - for fig, figname in zip(figs, fignames): - output_fig = output_dir / figname + for k, (figname, fig) in enumerate(eval_results["figs"].items()): + output_fig = output_dir / (path_compatible(figname) + ".pdf") if fig is not None: fig.savefig(output_fig, bbox_inches="tight") print(f"Saved figures to {output_dir}") # Print metrics - print(f"L1: {l1}") - print(f"KL: {kl}") - print(f"JSD: {jsd}") - print(f"Corr (exp(logp), rewards): {corr_prob_traj_rew}") - print(f"Var (log(R) - logp): {var_logrew_logp}") - print(f"NLL: {nll}") + print("Metrics:") + for k, v in eval_results["metrics"].items(): + print(f"\t{k}: {v:.4f}") # ------------------------------------------ # ----- Sample GFlowNet ----- From 294e9b5ed4c05a3deb309d564dc83d6eb1d4cd1c Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 16 Feb 2024 14:56:33 -0500 Subject: [PATCH 006/106] quote in print --- gflownet/evaluator/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index a0bbf3af0..f4d25f7e6 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -283,7 +283,7 @@ def plot(self): Plots this evaluator should do. This is a base method that does nothing and should be overridden by subclasses. """ - print("Base evaluator plot method does not do anything.") + print("Base evaluator `plot()` method does not do anything.") def eval(self, metrics=_sentinel, **plot_kwargs): """ From a24efa38ae03fbc13b36e0d5479c3405730e285c Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 16 Feb 2024 14:56:57 -0500 Subject: [PATCH 007/106] no figs as empty dicts instead of `(None,)` --- gflownet/evaluator/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index f4d25f7e6..fdf8dd993 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -318,7 +318,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): k: getattr(gfn, k) if hasattr(gfn, k) else None for k in metrics } } - result["figs"] = (None,) + result["figs"] = {} result["env_metrics"] = {} return result.values() @@ -426,7 +426,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): "mean_probs_std": mean_probs_std, "logprobs_std_nll_ratio": logprobs_std_nll_ratio, }, - "figs": (None,), + "figs": {}, "env_metrics": env_metrics, } # L1 error From 70252e5b5b7cae2c67456450ef776f4b9a38d18f Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 16 Feb 2024 14:57:14 -0500 Subject: [PATCH 008/106] fix `self` to `gfn` in `eval_top_k` --- gflownet/evaluator/base.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index fdf8dd993..d59576141 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -497,25 +497,24 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): duration = None summary = {} prob = copy.deepcopy(self.random_action_prob) + gfn = self.gfn_agent print() if not gfn_states: # sample states from the current gfn - batch = Batch(env=self.env, device=self.device, float_type=self.float) - self.random_action_prob = 0 + batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) + gfn.random_action_prob = 0 t = time.time() print("Sampling from GFN...", end="\r") - for b in batch_with_rest( - 0, self.logger.test.n_top_k, self.batch_size_total - ): - sub_batch, _ = self.sample_batch(n_forward=len(b), train=False) + for b in batch_with_rest(0, gfn.logger.test.n_top_k, gfn.batch_size_total): + sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) batch.merge(sub_batch) duration = time.time() - t gfn_states = batch.get_terminating_states() # compute metrics and get plots print("[eval_top_k] Making GFN plots...", end="\r") - metrics, figs, fig_names = self.env.top_k_metrics_and_plots( - gfn_states, self.logger.test.top_k, name="gflownet", step=it + metrics, figs, fig_names = gfn.env.top_k_metrics_and_plots( + gfn_states, gfn.logger.test.top_k, name="gflownet", step=it ) if duration: metrics["gflownet top k sampling duration"] = duration @@ -523,13 +522,13 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): if do_random: # sample random states from uniform actions if not random_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float) - self.random_action_prob = 1.0 + batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) + gfn.random_action_prob = 1.0 print("[eval_top_k] Sampling at random...", end="\r") for b in batch_with_rest( - 0, self.logger.test.n_top_k, self.batch_size_total + 0, gfn.logger.test.n_top_k, gfn.batch_size_total ): - sub_batch, _ = self.sample_batch(n_forward=len(b), train=False) + sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) batch.merge(sub_batch) # compute metrics and get plots random_states = batch.get_terminating_states() @@ -538,8 +537,8 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): random_metrics, random_figs, random_fig_names, - ) = self.env.top_k_metrics_and_plots( - random_states, self.logger.test.top_k, name="random", step=None + ) = gfn.env.top_k_metrics_and_plots( + random_states, gfn.logger.test.top_k, name="random", step=None ) # add to current metrics and plots summary.update(random_metrics) @@ -551,15 +550,15 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): train_metrics, train_figs, train_fig_names, - ) = self.env.top_k_metrics_and_plots( - None, self.logger.test.top_k, name="train", step=None + ) = gfn.env.top_k_metrics_and_plots( + None, gfn.logger.test.top_k, name="train", step=None ) # add to current metrics and plots summary.update(train_metrics) figs += train_figs fig_names += train_fig_names - self.random_action_prob = prob + gfn.random_action_prob = prob print(" " * 100, end="\r") print("eval_top_k metrics:") From 9769016e51c94b3a667a06dcd4ac11234229e39a Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 16 Feb 2024 15:41:02 -0500 Subject: [PATCH 009/106] `load_gflow_net_from_run_path` returns a tuple --- scripts/dav_mp20_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 3df1c78c9..9c0a27630 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -427,7 +427,7 @@ def plot_reward_hist(data, id_str, top_k): dave_config["rescale_outputs"] = True dave = DAVE(**dave_config) - gflownet = load_gflow_net_from_run_path(args.gflownet_path, device=dave.device) + gflownet, _ = load_gflow_net_from_run_path(args.gflownet_path, device=dave.device) loaders = make_loaders(dave.model_config) loaders["train"].dataset.ytransform = False From bd212ad2566d6eceed0a98b98d2ebd47f64e70b2 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 16 Feb 2024 15:41:43 -0500 Subject: [PATCH 010/106] move legacy code --- gflownet/utils/common.py | 17 ----------------- gflownet/utils/legacy.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 697281976..8b1bbe06d 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -55,23 +55,6 @@ def torch2np(x): return np.array(x) -def handle_logdir(): - # TODO - just copy-pasted - if "logdir" in config and config.logdir is not None: - if not Path(config.logdir).exists() or config.overwrite_logdir: - Path(config.logdir).mkdir(parents=True, exist_ok=True) - with open(config.logdir + "/config.yml", "w") as f: - yaml.dump( - numpy2python(namespace2dict(config)), f, default_flow_style=False - ) - torch.set_num_threads(1) - main(config) - else: - print(f"logdir {config.logdir} already exists! - Ending run...") - else: - print(f"working directory not defined - Ending run...") - - def download_file_if_not_exists(path: str, url: str): """ Download a file from google drive if path doestn't exist. diff --git a/gflownet/utils/legacy.py b/gflownet/utils/legacy.py index ff5226ec8..33fb12da4 100644 --- a/gflownet/utils/legacy.py +++ b/gflownet/utils/legacy.py @@ -960,3 +960,20 @@ def add_bool_arg(parser, name, default=False): group.add_argument("--no-" + name, dest=name, action="store_false") parser.set_defaults(**{name: default}) return parser + + +def handle_logdir(): + # TODO - just copy-pasted + if "logdir" in config and config.logdir is not None: + if not Path(config.logdir).exists() or config.overwrite_logdir: + Path(config.logdir).mkdir(parents=True, exist_ok=True) + with open(config.logdir + "/config.yml", "w") as f: + yaml.dump( + numpy2python(namespace2dict(config)), f, default_flow_style=False + ) + torch.set_num_threads(1) + main(config) + else: + print(f"logdir {config.logdir} already exists! - Ending run...") + else: + print(f"working directory not defined - Ending run...") From 73b1b15da1eba468abf9efa58fbb63f5fda48dc2 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 16 Feb 2024 15:41:59 -0500 Subject: [PATCH 011/106] `load_gflow_net_from_run_path` returns a tuple --- gflownet/evaluator/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index d59576141..5a782d667 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -245,7 +245,7 @@ def from_dir( GFlowNetEvaluator Instance of GFlowNetEvaluator with the GFlowNetAgent loaded from the run. """ - gfn_agent = load_gflow_net_from_run_path( + gfn_agent, _ = load_gflow_net_from_run_path( path, no_wandb=no_wandb, print_config=print_config, From cddc5d607cbae597f0ed6b4164a4164b13906e5c Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 16 Feb 2024 15:42:07 -0500 Subject: [PATCH 012/106] DOOOCSTRIIIINGS --- gflownet/utils/common.py | 320 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 313 insertions(+), 7 deletions(-) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 8b1bbe06d..8dccb6698 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -15,6 +15,30 @@ def set_device(device: Union[str, torch.device]): + """ + Get `torch` device from device. + + Examples + -------- + >>> set_device("cuda") + device(type='cuda') + + >>> set_device("cpu") + device(type='cpu') + + >>> set_device(torch.device("cuda")) + device(type='cuda') + + Parameters + ---------- + device : Union[str, torch.device] + Device. + + Returns + ------- + torch.device + `torch` device. + """ if isinstance(device, torch.device): return device if device.lower() == "cuda" and torch.cuda.is_available(): @@ -24,6 +48,32 @@ def set_device(device: Union[str, torch.device]): def set_float_precision(precision: Union[int, torch.dtype]): + """ + Get `torch` float type from precision. + + Examples + -------- + >>> set_float_precision(32) + torch.float32 + + >>> set_float_precision(torch.float32) + torch.float32 + + Parameters + ---------- + precision : Union[int, torch.dtype] + Precision. + + Returns + ------- + torch.dtype + `torch` float type. + + Raises + ------ + ValueError + If precision is not one of [16, 32, 64]. + """ if isinstance(precision, torch.dtype): return precision if precision == 16: @@ -37,6 +87,32 @@ def set_float_precision(precision: Union[int, torch.dtype]): def set_int_precision(precision: Union[int, torch.dtype]): + """ + Get `torch` integer type from `int` precision. + + Examples + -------- + >>> set_int_precision(32) + torch.int32 + + >>> set_int_precision(torch.int32) + torch.int32 + + Parameters + ---------- + precision : Union[int, torch.dtype] + Integer precision. + + Returns + ------- + torch.dtype + `torch` integer type. + + Raises + ------ + ValueError + If precision is not one of [16, 32, 64]. + """ if isinstance(precision, torch.dtype): return precision if precision == 16: @@ -50,6 +126,19 @@ def set_int_precision(precision: Union[int, torch.dtype]): def torch2np(x): + """ + Convert a torch tensor to a numpy array. + + Parameters + ---------- + x : Union[torch.Tensor, np.ndarray, list] + Data to be converted. + + Returns + ------- + np.ndarray + Converted data. + """ if hasattr(x, "is_cuda") and x.is_cuda: x = x.detach().cpu() return np.array(x) @@ -74,10 +163,54 @@ def download_file_if_not_exists(path: str, url: str): def resolve_path(path: str) -> Path: + """ + Resolve a path by expanding environment variables, user home directory, and making + it absolute. + + Examples + -------- + >>> resolve_path("~/scratch/$SLURM_JOB_ID/data") + Path("/home/user/scratch/12345/data") + + Parameters + ---------- + path : Union[str, Path] + Path to be resolved. + + Returns + ------- + Path + Resolved path. + """ return Path(expandvars(str(path))).expanduser().resolve() def find_latest_checkpoint(ckpt_dir, ckpt_name): + """ + Find the latest checkpoint in the directory with the specified name. + + If the checkpoint name contains the string "final", that checkpoint is returned. + Otherwise, the latest checkpoint is returned based on the iteration number. + + Parameters + ---------- + ckpt_dir : Union[str, Path] + Directory in which to search for the checkpoints. + ckpt_name : str + Name of the checkpoint. Typically, this is the name of forward or backward + policy. + + Returns + ------- + Path + Path to the latest checkpoint. + + Raises + ------ + ValueError + If no final checkpoint is found and no other checkpoints are found according to + the specified pattern: `{ckpt_name}*`. + """ ckpt_name = Path(ckpt_name).stem final = list(ckpt_dir.glob(f"{ckpt_name}*final*")) if len(final) > 0: @@ -97,6 +230,27 @@ def load_gflow_net_from_run_path( device="cuda", load_final_ckpt=True, ): + """ + Load GFlowNet from a run path (directory with a `.hydra` directory inside). + + Parameters + ---------- + run_path : Union[str, Path] + Path to the run directory. Must contain a `.hydra` directory. + no_wandb : bool, optional + Whether to disable wandb in the GFN init, by default True. + print_config : bool, optional + Whether to print the loaded config, by default False. + device : str, optional + Device to which the models should be moved, by default "cuda". + load_final_ckpt : bool, optional + Whether to load the final models, by default True. + + Returns + ------- + Tuple[GFN, DictConfig] + Loaded GFlowNet and the loaded config. + """ run_path = resolve_path(run_path) hydra_dir = run_path / ".hydra" @@ -176,6 +330,27 @@ def load_gflow_net_from_run_path( def batch_with_rest(start, stop, step, tensor=False): + """ + Yields batches of indices from start to stop with step size. The last batch may be + smaller than step. + + Parameters + ---------- + start : int + Start index + stop : int + End index (exclusive) + step : int + Step size + tensor : bool, optional + Whether to return a `torch` tensor of indices instead of a `numpy` array, by + default False. + + Yields + ------ + Union[np.ndarray, torch.Tensor] + Batch of indices + """ for i in range(start, stop, step): if tensor: yield torch.arange(i, min(i + step, stop)) @@ -184,6 +359,27 @@ def batch_with_rest(start, stop, step, tensor=False): def tfloat(x, device, float_type): + """ + Convert input to a float tensor. If the input is a list of tensors, the tensors + are stacked along the first dimension. + + The resulting tensor is moved to the specified device. + + Parameters + ---------- + x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int, + float]] + Input to be converted to a float tensor. + device : torch.device + Device to which the tensor should be moved. + float_type : torch.dtype + Float type to which the tensor should be converted. + + Returns + ------- + Union[torch.Tensor, List[torch.Tensor]] + Float tensor. + """ if isinstance(x, list) and torch.is_tensor(x[0]): return torch.stack(x).to(device=device, dtype=float_type) if torch.is_tensor(x): @@ -193,6 +389,25 @@ def tfloat(x, device, float_type): def tlong(x, device): + """ + Convert input to a long tensor. If the input is a list of tensors, the tensors + are stacked along the first dimension. + + The resulting tensor is moved to the specified device. + + Parameters + ---------- + x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int, + float]] + Input to be converted to a long tensor. + device : torch.device + Device to which the tensor should be moved. + + Returns + ------- + Union[torch.Tensor, List[torch.Tensor]] + Long tensor. + """ if isinstance(x, list) and torch.is_tensor(x[0]): return torch.stack(x).to(device=device, dtype=torch.long) if torch.is_tensor(x): @@ -202,6 +417,27 @@ def tlong(x, device): def tint(x, device, int_type): + """ + Convert input to an integer tensor. If the input is a list of tensors, the tensors + are stacked along the first dimension. + + The resulting tensor is moved to the specified device. + + Parameters + ---------- + x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int, + float]] + Input to be converted to an integer tensor. + device : torch.device + Device to which the tensor should be moved. + int_type : torch.dtype + Integer type to which the tensor should be converted. + + Returns + ------- + Union[torch.Tensor, List[torch.Tensor]] + Integer tensor. + """ if isinstance(x, list) and torch.is_tensor(x[0]): return torch.stack(x).to(device=device, dtype=int_type) if torch.is_tensor(x): @@ -211,6 +447,25 @@ def tint(x, device, int_type): def tbool(x, device): + """ + Convert input to a boolean tensor. If the input is a list of tensors, the tensors + are stacked along the first dimension. + + The resulting tensor is moved to the specified device. + + Parameters + ---------- + x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int, + float]] + Input to be converted to a boolean tensor. + device : torch.device + Device to which the tensor should be moved. + + Returns + ------- + Union[torch.Tensor, List[torch.Tensor]] + Boolean tensor. + """ if isinstance(x, list) and torch.is_tensor(x[0]): return torch.stack(x).to(device=device, dtype=torch.bool) if torch.is_tensor(x): @@ -219,16 +474,39 @@ def tbool(x, device): return torch.tensor(x, dtype=torch.bool, device=device) -def concat_items(list_of_items, index=None): +def concat_items(list_of_items, indices=None): + """ + Concatenates a list of items into a single tensor or array. + + Parameters + ---------- + list_of_items : + List of items to be concatenated, i.e. list of arrays or list of tensors. + indices : Union[List[np.ndarray], List[torch.Tensor]], optional + Indices to select in the resulting concatenated tensor or array, by default + None. + + Returns + ------- + Union[np.ndarray, torch.Tensor] + Concatenated tensor or array, with optional selection of indices. + + Raises + ------ + NotImplementedError + If the input type is not supported, i.e., not a list of arrays or a list of + tensors. + """ if isinstance(list_of_items[0], np.ndarray): result = np.concatenate(list_of_items) - if index is not None: - index = index.cpu().numpy() - result = result[index] + if indices is not None: + if torch.is_tensor(indices[0]): + indices = indices.cpu().numpy() + result = result[indices] elif torch.is_tensor(list_of_items[0]): result = torch.cat(list_of_items) - if index is not None: - result = result[index] + if indices is not None: + result = result[indices] else: raise NotImplementedError( "cannot concatenate {}".format(type(list_of_items[0])) @@ -240,7 +518,20 @@ def concat_items(list_of_items, index=None): def extend( orig: Union[List, TensorType["..."]], new: Union[List, TensorType["..."]] ) -> Union[List, TensorType["..."]]: - assert type(orig) == type(new) + """ + Extends the original list or tensor with the new list or tensor. + + Returns + ------- + Union[List, TensorType["..."]] + Extended list or tensor. + + Raises + ------ + NotImplementedError + If the input type is not supported, i.e., not a list or a tensor. + """ + assert isinstance(orig, type(new)) if isinstance(orig, list): orig.extend(new) elif torch.tensor(orig): @@ -253,6 +544,21 @@ def extend( def copy(x: Union[List, TensorType["..."]]): + """ + Makes copy of the input tensor or list. + + A tensor is cloned and detached from the computational graph. + + Parameters + ---------- + x : Union[List, TensorType["..."]] + Input tensor or list to be copied. + + Returns + ------- + Union[List, TensorType["..."]] + Copy of the input tensor or list. + """ if torch.is_tensor(x): return x.clone().detach() else: From 54f2f23b6d8c962a0e6b60d6e06876e14d0040c9 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Sun, 18 Feb 2024 11:44:52 -0500 Subject: [PATCH 013/106] `@classmethod` --- gflownet/evaluator/base.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 5a782d667..0686bf658 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -216,8 +216,9 @@ def do_checkpoints(self, step): else: return not step % self.checkpoints.period - @staticmethod + @classmethod def from_dir( + cls: "GFlowNetEvaluator", path: Union[str, os.PathLike], no_wandb: bool = True, print_config: bool = False, @@ -229,6 +230,8 @@ def from_dir( Parameters ---------- + cls : GFlowNetEvaluator + Class to instantiate. path : Union[str, os.PathLike] Path to the run directory from which to load the GFlowNetAgent. no_wandb : bool, optional @@ -254,13 +257,15 @@ def from_dir( ) return GFlowNetEvaluator.from_agent(gfn_agent) - @staticmethod - def from_agent(gfn_agent): + @classmethod + def from_agent(cls, gfn_agent): """ Instantiate a GFlowNetEvaluator from a GFlowNetAgent. Parameters ---------- + cls : GFlowNetEvaluator + Evaluator class to instantiate. gfn_agent : GFlowNetAgent Instance of GFlowNetAgent to use for the GFlowNetEvaluator. From f546506e5550cec4a8997ffe4a633c55ce332278 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 10:13:00 -0500 Subject: [PATCH 014/106] unused `log_iter` --- gflownet/gflownet.py | 41 ----------------------------------------- 1 file changed, 41 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index a4eb0ab9c..f1d408eff 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1135,47 +1135,6 @@ def get_log_corr(self, times): corr = np.corrcoef(data_logq, self.buffer.test["energies"]) return corr, data_logq, times - # TODO: reorganize and remove - def log_iter( - self, - pbar, - rewards, - proxy_vals, - states_term, - data, - it, - times, - losses, - all_losses, - all_visited, - ): - # train metrics - self.logger.log_sampler_train( - rewards, proxy_vals, states_term, data, it, self.use_context - ) - - # logZ - self.logger.log_metric("logZ", self.logZ.sum(), it, use_context=False) - - # test metrics - # TODO: integrate corr into test() - if not self.logger.lightweight and self.buffer.test is not None: - corr, data_logq, times = self.get_log_corr(times) - self.logger.log_sampler_test(corr, data_logq, it, self.use_context) - - # oracle metrics - oracle_batch, oracle_times = self.sample_batch( - n_forward=self.oracle_n, train=False - ) - - if not self.logger.lightweight: - self.logger.log_metric( - "unique_states", - np.unique(all_visited).shape[0], - step=it, - use_context=self.use_context, - ) - def make_opt(params, logZ, config): """ From 1225034153eba0be4121ee765addf304998f90f2 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 10:44:32 -0500 Subject: [PATCH 015/106] GFNA init docstring --- gflownet/gflownet.py | 68 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index f1d408eff..91981480e 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -48,14 +48,80 @@ def __init__( logger, num_empirical_loss, oracle, + eval_config, state_flow=None, active_learning=False, sample_only=False, replay_sampling="permutation", train_sampling="permutation", - eval_config=None, **kwargs, ): + """ + Main class of this repository. Handles + + Parameters + ---------- + env : GFlowNetEnv + The environment to be used for training, i.e. the DAG, action space and + reward function. + seed : int + Random seed to be used for reproducibility. + device : str + Device to be used for training and inference, e.g. "cuda" or "cpu". + float_precision : int + Precision of the floating point numbers, e.g. 32 or 64. + optimizer : dict + Optimizer config dictionary. See gflownet.yaml:optimizer for details. + buffer : dict + Buffer config dictionary. See gflownet.yaml:buffer for details. + forward_policy : gflownet.policy.base.Policy + The forward policy to be used for training. Parameterized from + `gflownet.yaml:forward_policy` and parsed with + `gflownet/utils/policy.py:set_policy`. + backward_policy : gflownet.policy.base.Policy + Same as forward_policy, but for the backward policy. + mask_invalid_actions : bool + Whether to mask invalid actions in the policy outputs. + temperature_logits : float + Temperature to adjust the logits by logits /= temperature. If None, + self.temperature_logits is used. + random_action_prob : float + Probability of sampling random actions. If None (default), + self.random_action_prob is used, unless its value is forced to either 0.0 or + 1.0 by other arguments (sampling_method or no_random). + pct_offline : float + Percentage of offline data to be used for training. + logger : gflownet.utils.logger.Logger + Logger object to be used for logging and saving checkpoints + (`gflownet/utils/logger.py:Logger`). + num_empirical_loss : int + Number of empirical loss samples to be used for training. + oracle : dict + Oracle config dictionary. See gflownet.yaml:oracle for details. + eval_config : dict, optional + Evaluator config dictionary. See `eval/base.yaml` for details. By default + None. + state_flow : dict, optional + State flow config dictionary. See `gflownet.yaml:state_flow` for details. By + default None. + active_learning : bool, optional + Whether this GFlowNetAgent is part of an active learning loop, by default + False. This means the logger will use its context in metrics names. + sample_only : bool, optional + This GFNA is only going to be used to sample, no need to make the train/test + buffer. + replay_sampling : str, optional + Type of sampling for the replay buffer. See + :method:`~gflownet.utils.buffer.select`. By default "permutation". + train_sampling : str, optional + Type of sampling for the train buffer (offline backward trajectories). See + :method:`~gflownet.utils.buffer.select`. By default "permutation". + + Raises + ------ + Exception + If the loss is flowmatch/flowmatching and the environment is continuous. + """ # Seed self.rng = np.random.default_rng(seed) # Device From f7662d3406b2d65e993686c3b9a66900aae24676 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 10:46:24 -0500 Subject: [PATCH 016/106] refactor `should_` `train/eval/checkpoint`etc. --- gflownet/evaluator/base.py | 6 ++--- gflownet/gflownet.py | 36 +++++++++++++++------------- gflownet/utils/logger.py | 49 ++++++++++++++++---------------------- 3 files changed, 43 insertions(+), 48 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 0686bf658..acb59643e 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -116,7 +116,7 @@ def do_train(self, step): else: return not step % self.config.train.period - def do_test(self, step): + def should_eval(self, step): """ Check if testing should be done at the current step. The decision is based on the `self.config.test.period` attribute. @@ -144,7 +144,7 @@ def do_test(self, step): else: return not step % self.config.test.period - def do_top_k(self, step): + def should_eval_top_k(self, step): """ Check if top k plots and metrics should be done at the current step. The decision is based on the `self.config.test.top_k` and @@ -193,7 +193,7 @@ def do_oracle(self, step): else: return not step % self.oracle.period - def do_checkpoints(self, step): + def should_checkpoint(self, step): """ Check if checkpoints should be done at the current step. The decision is based on the `self.checkpoints.period` attribute. diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 91981480e..d71336100 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1045,9 +1045,9 @@ def train(self): pbar = tqdm(range(1, self.n_train_steps + 1), disable=not self.logger.progress) for it in pbar: # Test and log - if self.evaluator.do_test(it): + if self.evaluator.should_eval(it): self.evaluator.eval_and_log(it) - if self.evaluator.do_top_k(it): + if self.evaluator.should_eval_top_k(it): self.evaluator.eval_and_log_top_k(it) t0_iter = time.time() @@ -1111,7 +1111,7 @@ def train(self): # Log if self.logger.lightweight: all_losses = all_losses[-100:] - all_visited = states_term + all_visited = states_term # TODO-V: unused else: all_visited.extend(states_term) # Progress bar @@ -1120,24 +1120,26 @@ def train(self): ) # Train logs t0_log = time.time() - self.logger.log_train( - losses=losses, - rewards=rewards, - proxy_vals=proxy_vals, - states_term=states_term, - batch_size=len(batch), - logz=self.logZ, - learning_rates=self.lr_scheduler.get_last_lr(), - step=it, - use_context=self.use_context, - ) + if self.evaluator.should_log_train(it): + self.logger.log_train( + losses=losses, + rewards=rewards, + proxy_vals=proxy_vals, + states_term=states_term, + batch_size=len(batch), + logz=self.logZ, + learning_rates=self.lr_scheduler.get_last_lr(), + step=it, + use_context=self.use_context, + ) t1_log = time.time() times.update({"log": t1_log - t0_log}) # Save intermediate models t0_model = time.time() - self.logger.save_models( - self.forward_policy, self.backward_policy, self.state_flow, step=it - ) + if self.evaluator.should_checkpoint(it): + self.logger.save_models( + self.forward_policy, self.backward_policy, self.state_flow, step=it + ) t1_model = time.time() times.update({"save_interim_model": t1_model - t0_model}) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 3ca8cfd46..8d221aa04 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -23,8 +23,6 @@ def __init__( do: dict, project_name: str, logdir: dict, - train: dict, - test: dict, oracle: dict, checkpoints: dict, progress: bool, @@ -38,8 +36,6 @@ def __init__( self.config = config self.do = do self.do.times = self.do.times and self.do.online - self.train = train - self.test = test self.oracle = oracle self.checkpoints = checkpoints slurm_job_id = os.environ.get("SLURM_JOB_ID") @@ -195,8 +191,6 @@ def log_train( step: int, use_context: bool, ): - if not self.do.online or not self.do_train(step): - return if logz is None: logz = 0.0 else: @@ -255,7 +249,7 @@ def log_sampler_test( ): if not self.do.online: return - if self.do_test(step): + if self.should_eval(step): test_metrics = dict( zip( [ @@ -355,28 +349,27 @@ def log_test_metrics( def save_models( self, forward_policy, backward_policy, state_flow, step: int = 1e9, final=False ): - if self.do_checkpoints(step) or final: - if final: - ckpt_id = "final" - else: - ckpt_id = "_iter{:06d}".format(step) - if forward_policy.is_model and self.pf_ckpt_path is not None: - stem = self.pf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" - path = self.pf_ckpt_path.parent / stem - torch.save(forward_policy.model.state_dict(), path) - if ( - backward_policy - and backward_policy.is_model - and self.pb_ckpt_path is not None - ): - stem = self.pb_ckpt_path.stem + self.context + ckpt_id + ".ckpt" - path = self.pb_ckpt_path.parent / stem - torch.save(backward_policy.model.state_dict(), path) + if final: + ckpt_id = "final" + else: + ckpt_id = "_iter{:06d}".format(step) + if forward_policy.is_model and self.pf_ckpt_path is not None: + stem = self.pf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" + path = self.pf_ckpt_path.parent / stem + torch.save(forward_policy.model.state_dict(), path) + if ( + backward_policy + and backward_policy.is_model + and self.pb_ckpt_path is not None + ): + stem = self.pb_ckpt_path.stem + self.context + ckpt_id + ".ckpt" + path = self.pb_ckpt_path.parent / stem + torch.save(backward_policy.model.state_dict(), path) - if state_flow is not None and self.sf_ckpt_path is not None: - stem = self.sf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" - path = self.sf_ckpt_path.parent / stem - torch.save(state_flow.model.state_dict(), path) + if state_flow is not None and self.sf_ckpt_path is not None: + stem = self.sf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" + path = self.sf_ckpt_path.parent / stem + torch.save(state_flow.model.state_dict(), path) def log_time(self, times: dict, use_context: bool): if self.do.times: From 924dc189910fbc656f58b1348c80ad2d81a2d236 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 11:55:14 -0500 Subject: [PATCH 017/106] don't log `None` values --- gflownet/utils/logger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 8d221aa04..e272d8b09 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -172,7 +172,8 @@ def log_metrics(self, metrics: dict, step: int, use_context: bool = True): if not self.do.online: return for key, value in metrics.items(): - self.log_metric(key, value, step=step, use_context=use_context) + if value is not None: + self.log_metric(key, value, step=step, use_context=use_context) def log_summary(self, summary: dict): if not self.do.online: From 274ff79c9655ff4709e2d851251bd0b69b88d096 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 11:56:01 -0500 Subject: [PATCH 018/106] No need for a dedicated `log_test_metrics` --- gflownet/evaluator/base.py | 4 +++- gflownet/utils/logger.py | 48 -------------------------------------- 2 files changed, 3 insertions(+), 49 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index acb59643e..804aabfbe 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -601,7 +601,9 @@ def eval_and_log(self, it, metrics=_sentinel): for m, v in result["metrics"].items(): setattr(gfn, m, v) - self.logger.log_test_metrics(*result["metrics"].values(), it, gfn.use_context) + mertics_to_log = {METRICS[k]["name"]: v for k, v in result["metrics"].values()} + + self.logger.log_metrics(mertics_to_log, it, gfn.use_context) self.logger.log_metrics(result["env_metrics"], it, use_context=gfn.use_context) self.logger.log_plots(result["figs"], it, use_context=gfn.use_context) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index e272d8b09..9c41dd037 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -299,54 +299,6 @@ def log_losses( use_context=use_context, ) - def log_test_metrics( - self, - l1: float, - kl: float, - jsd: float, - corr_prob_traj_rewards: float, - var_logrewards_logp: float, - nll_tt: float, - mean_logprobs_std: float, - mean_probs_std: float, - logprobs_std_nll_ratio: float, - step: int, - use_context: bool, - ): - if not self.do.online: - return - metrics = dict( - zip( - [ - "L1 error", - "KL Div.", - "Jensen Shannon Div.", - "Corr. (test probs., rewards)", - "Var(logR - logp) test", - "NLL of test data", - "Mean BS Std(logp)", - "Mean BS Std(p)", - "BS Std(logp) / NLL", - ], - [ - l1, - kl, - jsd, - corr_prob_traj_rewards, - var_logrewards_logp, - nll_tt, - mean_logprobs_std, - mean_probs_std, - logprobs_std_nll_ratio, - ], - ) - ) - self.log_metrics( - metrics, - use_context=use_context, - step=step, - ) - def save_models( self, forward_policy, backward_policy, state_flow, step: int = 1e9, final=False ): From f75c8762f7489d40b945970201c3137d9e58277a Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 11:57:24 -0500 Subject: [PATCH 019/106] move figs to `plot(...)` --- gflownet/evaluator/base.py | 74 +++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 21 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 804aabfbe..8259c872a 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -283,12 +283,59 @@ def from_agent(cls, gfn_agent): return GFlowNetEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) - def plot(self): + def plot(self, x_sampled=None, kde_pred=None, kde_true=None, **plot_kwargs): """ - Plots this evaluator should do. This is a base method that does nothing and - should be overridden by subclasses. + Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which + will be logged. + + By default, this method will call the `plot_reward_samples` method of the + GFlowNetAgent's environment, and the `plot_kde` method of the GFlowNetAgent's + environment if it exists for both the `kde_pred` and `kde_true` arguments. + + Extend this method to add more plots: + + ```python + def plot(self, x_sampled, kde_pred, kde_true, **plot_kwargs): + figs = super().plot(x_sampled, kde_pred, kde_true, **plot_kwargs) figs["My + custom plot"] = my_custom_plot_function(x_sampled, kde_pred) return figs + ``` + + Parameters + ---------- + x_sampled : list, optional + List of sampled states. + kde_pred : sklearn.neighbors.KernelDensity + KDE policy as per `Environment.fit_kde` + kde_true : object + True KDE. + plot_kwargs : dict + Additional keyword arguments to pass to the plotting methods. + + Returns + ------- + dict[str, plt.Figure] + Dictionary of figures to be logged. The keys are the figure names and the + values are the figures. """ - print("Base evaluator `plot()` method does not do anything.") + gfn = self.gfn_agent + + fig_kde_pred = fig_kde_true = fig_reward_samples = None + + if hasattr(gfn.env, "plot_reward_samples") and x_sampled is not None: + fig_reward_samples = gfn.env.plot_reward_samples(x_sampled, **plot_kwargs) + + if hasattr(gfn.env, "plot_kde"): + if kde_pred is not None: + fig_kde_pred = gfn.env.plot_kde(kde_pred, **plot_kwargs) + if kde_true is not None: + fig_kde_true = gfn.env.plot_kde(kde_true, **plot_kwargs) + + return { + "True reward and GFlowNet samples": fig_reward_samples, + "GFlowNet KDE Policy": fig_kde_pred, + "Reward KDE": fig_kde_true, + } + def eval(self, metrics=_sentinel, **plot_kwargs): """ @@ -443,18 +490,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens)) jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) - # Plots - # TODO-V: move to evaluator.plot()? - if hasattr(gfn.env, "plot_reward_samples"): - fig_reward_samples = gfn.env.plot_reward_samples(x_sampled, **plot_kwargs) - else: - fig_reward_samples = None - if hasattr(gfn.env, "plot_kde"): - fig_kde_pred = gfn.env.plot_kde(kde_pred, **plot_kwargs) - fig_kde_true = gfn.env.plot_kde(kde_true, **plot_kwargs) - else: - fig_kde_pred = None - fig_kde_true = None + figs = self.plot(x_sampled=x_sampled, kde_pred=kde_pred, kde_true=kde_true) return { "metrics": { @@ -468,11 +504,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): "mean_probs_std": mean_probs_std, "logprobs_std_nll_ratio": logprobs_std_nll_ratio, }, - "figs": { - "True reward and GFlowNet samples": fig_reward_samples, - "GFlowNet KDE Policy": fig_kde_pred, - "Reward KDE": fig_kde_true, - }, + "figs": figs, "env_metrics": {}, } From 1fdb4013475db03619821b18b5efbd7589a042d5 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 11:59:02 -0500 Subject: [PATCH 020/106] setup `requires` system --- gflownet/evaluator/base.py | 69 ++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 8259c872a..5e210712c 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -19,16 +19,43 @@ _sentinel = object() -METRICS_NAMES = { - "l1": "L1 error", - "kl": "KL Div.", - "jsd": "Jensen Shannon Div.", - "corr_prob_traj_rewards": "Corr. (test probs., rewards)", - "var_logrewards_logp": "Var(logR - logp) test", - "nll_tt": "NLL of test data", - "mean_logprobs_std": "Mean BS Std(logp)", - "mean_probs_std": "Mean BS Std(p)", - "logprobs_std_nll_ratio": "BS Std(logp) / NLL", +METRICS = { + "l1": { + "name": "L1 error", + "requires": ["density"], + }, + "kl": { + "name": "KL Div.", + "requires": ["density"], + }, + "jsd": { + "name": "Jensen Shannon Div.", + "requires": ["density"], + }, + "corr_prob_traj_rewards": { + "name": "Corr. (test probs., rewards)", + "requires": ["log_probs", "reward_batch"], + }, + "var_logrewards_logp": { + "name": "Var(logR - logp) test", + "requires": ["log_probs", "reward_batch"], + }, + "nll_tt": { + "name": "NLL of test data", + "requires": ["log_probs"], + }, + "mean_logprobs_std": { + "name": "Mean BS Std(logp)", + "requires": ["log_probs"], + }, + "mean_probs_std": { + "name": "Mean BS Std(p)", + "requires": ["log_probs"], + }, + "logprobs_std_nll_ratio": { + "name": "BS Std(logp) / NLL", + "requires": ["log_probs"], + }, } @@ -59,6 +86,7 @@ def __init__(self, **kwargs): self.gfn_agent = kwargs.get("gfn_agent") self.config = self.gfn_agent.eval_config self.logger = self.gfn_agent.logger + self.requires = set() self.set_metrics(self.config.metrics) @@ -67,31 +95,38 @@ def set_metrics(self, metrics=None): Set the metrics to be computed by the evaluator to the `self.metrics` attribute. If `None`, all metrics are computed. If a string, it can be a comma-separated - list of metric names, with or without spaces. All metrics must be in - `METRICS_NAMES`. + list of metric names, with or without spaces. All metrics must be in `METRICS`. + + Sets the `self.metrics` attribute to a dictionary of metrics to be computed + according to the `METRICS` dictionary. In other words, `self.metrics` will be + a subset of `METRICS`. Parameters ---------- metrics : (Union[str, List[str]], optional) Metrics to compute when running the `evaluator.eval()` function. Defaults to - None, i.e. all metrics in `METRICS_NAMES` are computed. + None, i.e. all metrics in `METRICS` are computed. Raises ------ ValueError - If a metric name is not in `METRICS_NAMES`. + If a metric name is not in `METRICS`. """ if metrics is None: - metrics = METRICS_NAMES.keys() + metrics = list(METRICS.keys()) if isinstance(metrics, str): if "," in metrics: metrics = [m.strip() for m in metrics.split(",")] else: metrics = [metrics] for m in metrics: - if m not in METRICS_NAMES: + if m not in METRICS: raise ValueError(f"Unknown metric name: {m}") - self.metrics = metrics + + self.metrics = {k: METRICS[k] for k in metrics} + self.requires = set( + [r for m in self.metrics for r in self.metrics[m]["requires"]] + ) def do_train(self, step): """ From baf701a44b84ccd36ef760ddd19afa6a3354cf91 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 12:00:09 -0500 Subject: [PATCH 021/106] allow for custom `require` --- gflownet/evaluator/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 5e210712c..69ed9e6d9 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -397,7 +397,12 @@ def eval(self, metrics=_sentinel, **plot_kwargs): if metrics is None: # TODO-V use this in the rest of the code to selectively compute metrics - metrics = set(METRICS_NAMES.keys()) + metrics = set(METRICS.keys()) + requires = set([r for m in metrics for r in METRICS[m]["requires"]]) + + if metrics is _sentinel: + metrics = self.metrics + requires = self.requires if gfn.buffer.test_pkl is None: result = { From 00fc1b5cc935cb73b2b53a2440463e3c04549e23 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 12:00:22 -0500 Subject: [PATCH 022/106] typo returned dict --- gflownet/evaluator/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 69ed9e6d9..4fd14263a 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -412,7 +412,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): } result["figs"] = {} result["env_metrics"] = {} - return result.values() + return result with open(gfn.buffer.test_pkl, "rb") as f: dict_tt = pickle.load(f) From c1a5dc81816d5151db78cacf9574672a5433d7ff Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 12:00:51 -0500 Subject: [PATCH 023/106] move log prob metrcis to `compute_log_prob_metrics(...)` --- gflownet/evaluator/base.py | 91 ++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 32 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 4fd14263a..d6c44f3e2 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -371,6 +371,56 @@ def plot(self, x_sampled, kde_pred, kde_true, **plot_kwargs): "Reward KDE": fig_kde_true, } + def compute_log_prob_metrics(self, x_tt, gfn, dict_tt, requires=_sentinel): + gfn = self.gfn_agent + + if requires is None: + requires = set([r for m in self.metrics.values() for r in m["requires"]]) + if requires is _sentinel: + requires = self.requires + + logprobs_x_tt, logprobs_std, probs_std = gfn.estimate_logprobs_data( + x_tt, + n_trajectories=self.logger.test.n_trajs_logprobs, + max_data_size=self.logger.test.max_data_logprobs, + batch_size=self.logger.test.logprobs_batch_size, + bs_num_samples=self.logger.test.logprobs_bootstrap_size, + ) + + lp_metrics = {} + + if "mean_logprobs_std" in self.metrics: + lp_metrics["mean_logprobs_std"] = logprobs_std.mean().item() + + if "mean_probs_std" in self.metrics: + lp_metrics["mean_probs_std"] = probs_std.mean().item() + + if "reward_batch" in requires: + rewards_x_tt = gfn.env.reward_batch(x_tt) + + if "corr_prob_traj_rewards" in self.metrics: + rewards_x_tt = gfn.env.reward_batch(x_tt) + lp_metrics["corr_prob_traj_rewards"] = np.corrcoef( + np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt + )[0, 1] + + if "var_logrewards_logp" in self.metrics: + rewards_x_tt = gfn.env.reward_batch(x_tt) + lp_metrics["var_logrewards_logp"] = torch.var( + torch.log( + tfloat(rewards_x_tt, float_type=gfn.float, device=gfn.device) + ) + - logprobs_x_tt + ).item() + if "nll_tt" in self.metrics: + lp_metrics["nll_tt"] = -logprobs_x_tt.mean().item() + + if "logprobs_std_nll_ratio" in self.metrics: + lp_metrics["logprobs_std_nll_ratio"] = ( + -logprobs_std.mean() / logprobs_x_tt.mean() + ) + + return lp_metrics def eval(self, metrics=_sentinel, **plot_kwargs): """ @@ -381,7 +431,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): Parameters ---------- - metrics : _type_, optional + metrics : List[str], optional List of metrics to compute, by default the evaluator's `metrics` attribute. plot_kwargs : dict, optional Additional keyword arguments to pass to the plotting methods. @@ -394,6 +444,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): logprobs_std_nll_ratio, figs, env_metrics] (should be refactored to dict) """ gfn = self.gfn_agent + all_metrics = {} if metrics is None: # TODO-V use this in the rest of the code to selectively compute metrics @@ -421,25 +472,11 @@ def eval(self, metrics=_sentinel, **plot_kwargs): # Compute correlation between the rewards of the test data and the log # likelihood of the data according the the GFlowNet policy; and NLL. # TODO: organise code for better efficiency and readability - logprobs_x_tt, logprobs_std, probs_std = gfn.estimate_logprobs_data( - x_tt, - n_trajectories=self.logger.test.n_trajs_logprobs, - max_data_size=self.logger.test.max_data_logprobs, - batch_size=self.logger.test.logprobs_batch_size, - bs_num_samples=self.logger.test.logprobs_bootstrap_size, - ) - mean_logprobs_std = logprobs_std.mean().item() - mean_probs_std = probs_std.mean().item() - rewards_x_tt = gfn.env.reward_batch(x_tt) - corr_prob_traj_rewards = np.corrcoef( - np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt - )[0, 1] - var_logrewards_logp = torch.var( - torch.log(tfloat(rewards_x_tt, float_type=gfn.float, device=gfn.device)) - - logprobs_x_tt - ).item() - nll_tt = -logprobs_x_tt.mean().item() - logprobs_std_nll_ratio = torch.mean(-logprobs_std / logprobs_x_tt).item() + if "log_probs" in requires: + lp_metrics = self.compute_log_prob_metrics( + x_tt, gfn, dict_tt, requires=requires + ) + all_metrics.update(lp_metrics) x_sampled = [] if gfn.buffer.test_type is not None and gfn.buffer.test_type == "all": @@ -511,12 +548,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): "l1": gfn.l1, "kl": gfn.kl, "jsd": gfn.jsd, - "corr_prob_traj_rewards": corr_prob_traj_rewards, - "var_logrewards_logp": var_logrewards_logp, - "nll_tt": nll_tt, - "mean_logprobs_std": mean_logprobs_std, - "mean_probs_std": mean_probs_std, - "logprobs_std_nll_ratio": logprobs_std_nll_ratio, + **all_metrics, }, "figs": {}, "env_metrics": env_metrics, @@ -537,12 +569,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): "l1": l1, "kl": kl, "jsd": jsd, - "corr_prob_traj_rewards": corr_prob_traj_rewards, - "var_logrewards_logp": var_logrewards_logp, - "nll_tt": nll_tt, - "mean_logprobs_std": mean_logprobs_std, - "mean_probs_std": mean_probs_std, - "logprobs_std_nll_ratio": logprobs_std_nll_ratio, + **all_metrics, }, "figs": figs, "env_metrics": {}, From 2f08b0d84b35b89965bcf99c39f8d0068e045c7b Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 12:31:37 -0500 Subject: [PATCH 024/106] improve `make_metrics` and `make_requires` --- gflownet/evaluator/base.py | 132 ++++++++++++++++++++++++++----------- 1 file changed, 92 insertions(+), 40 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index d6c44f3e2..1b86c6b14 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -88,46 +88,108 @@ def __init__(self, **kwargs): self.logger = self.gfn_agent.logger self.requires = set() - self.set_metrics(self.config.metrics) + self.metrics = self.requires = _sentinel + self.metrics = self.make_metrics(self.config.metrics) + self.requires = self.make_require() - def set_metrics(self, metrics=None): + def make_metrics(self, metrics=None): """ - Set the metrics to be computed by the evaluator to the `self.metrics` attribute. + Parse metrics from a list, a string or None. If `None`, all metrics are computed. If a string, it can be a comma-separated list of metric names, with or without spaces. All metrics must be in `METRICS`. - Sets the `self.metrics` attribute to a dictionary of metrics to be computed - according to the `METRICS` dictionary. In other words, `self.metrics` will be - a subset of `METRICS`. - Parameters ---------- metrics : (Union[str, List[str]], optional) Metrics to compute when running the `evaluator.eval()` function. Defaults to None, i.e. all metrics in `METRICS` are computed. + Returns + ------- + dict + Dictionary of metrics to compute, with the metric names as keys and the + metric names and requires as values. + Raises ------ ValueError If a metric name is not in `METRICS`. """ + if metrics == "all": + metrics = METRICS.keys() + if metrics is None: - metrics = list(METRICS.keys()) + assert self.metrics is not _sentinel, ( + "Error setting self.metrics. This is likely due to the `metrics:`" + + " entry missing from your eval config. Set it to 'all' to compute all" + + " metrics or to a comma-separated list of metric names (eg 'l1, kl')." + ) + return self.metrics + if isinstance(metrics, str): if "," in metrics: metrics = [m.strip() for m in metrics.split(",")] else: metrics = [metrics] + for m in metrics: if m not in METRICS: raise ValueError(f"Unknown metric name: {m}") - self.metrics = {k: METRICS[k] for k in metrics} - self.requires = set( - [r for m in self.metrics for r in self.metrics[m]["requires"]] + return {m: METRICS[m] for m in metrics} + + def make_requires(self, requires=None, metrics=None): + """ + Make requirements for the metrics to compute. + + 1. If `metrics` is provided, they must be as a dict of metrics. The requirements + are computed from the `requires` attribute of the metrics. + + 2. Otherwise, the requirements are computed from the `requires` argument: + - If `requires` is `"all"`, all requirements of all metrics are computed. + - If `requires` is `None`, the evaluator's `self.requires` attribute is + used. + - If `requires` is a list, it is used as the requirements. + + Parameters + ---------- + requires : Union[str, List[str]], optional + The metrics requirements. Either `"all"`, a list of requirements or `None` + to use the evaluator's `self.requires` attribute. By default None + metrics : List[str], optional + The list of metrics dicts to compute requirements for. By default None. + + Returns + ------- + set[str] + The set of requirements for the metrics. + """ + + if metrics is not None: + return set([r for m in metrics.values() for r in m["requires"]]) + + if requires == "all": + requires = set([r for m in METRICS.values() for r in m["requires"]]) + if requires is None: + if self.requires is _sentinel: + self.requires = set( + [r for m in self.metrics.values() for r in m["requires"]] + ) + requires = self.requires + if isinstance(requires, list): + requires = set(requires) + + assert isinstance( + requires, set + ), f"requires should be a set, but is {type(requires)}" + assert all([isinstance(r, str) for r in requires]), ( + "All elements of requires should be strings, but are " + + f"{[type(r) for r in requires]}" ) + return requires + def do_train(self, step): """ Check if training logs should be done at the current step. The decision is based @@ -371,13 +433,10 @@ def plot(self, x_sampled, kde_pred, kde_true, **plot_kwargs): "Reward KDE": fig_kde_true, } - def compute_log_prob_metrics(self, x_tt, gfn, dict_tt, requires=_sentinel): + def compute_log_prob_metrics(self, x_tt, gfn, metrics=None): gfn = self.gfn_agent - - if requires is None: - requires = set([r for m in self.metrics.values() for r in m["requires"]]) - if requires is _sentinel: - requires = self.requires + metrics = self.make_metrics(metrics) + requires = self.make_requires(metrics=metrics) logprobs_x_tt, logprobs_std, probs_std = gfn.estimate_logprobs_data( x_tt, @@ -389,22 +448,22 @@ def compute_log_prob_metrics(self, x_tt, gfn, dict_tt, requires=_sentinel): lp_metrics = {} - if "mean_logprobs_std" in self.metrics: + if "mean_logprobs_std" in metrics: lp_metrics["mean_logprobs_std"] = logprobs_std.mean().item() - if "mean_probs_std" in self.metrics: + if "mean_probs_std" in metrics: lp_metrics["mean_probs_std"] = probs_std.mean().item() if "reward_batch" in requires: rewards_x_tt = gfn.env.reward_batch(x_tt) - if "corr_prob_traj_rewards" in self.metrics: + if "corr_prob_traj_rewards" in metrics: rewards_x_tt = gfn.env.reward_batch(x_tt) lp_metrics["corr_prob_traj_rewards"] = np.corrcoef( np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt )[0, 1] - if "var_logrewards_logp" in self.metrics: + if "var_logrewards_logp" in metrics: rewards_x_tt = gfn.env.reward_batch(x_tt) lp_metrics["var_logrewards_logp"] = torch.var( torch.log( @@ -412,27 +471,28 @@ def compute_log_prob_metrics(self, x_tt, gfn, dict_tt, requires=_sentinel): ) - logprobs_x_tt ).item() - if "nll_tt" in self.metrics: + if "nll_tt" in metrics: lp_metrics["nll_tt"] = -logprobs_x_tt.mean().item() - if "logprobs_std_nll_ratio" in self.metrics: + if "logprobs_std_nll_ratio" in metrics: lp_metrics["logprobs_std_nll_ratio"] = ( -logprobs_std.mean() / logprobs_x_tt.mean() ) return lp_metrics - def eval(self, metrics=_sentinel, **plot_kwargs): + def eval(self, metrics=None, **plot_kwargs): """ Evaluate the GFlowNetAgent and compute metrics and plots. - If `metrics` is not provided, the evaluator's `metrics` attribute is - used (default). + If `metrics` is not provided, the evaluator's `self.metrics` attribute is used + (default). Parameters ---------- metrics : List[str], optional - List of metrics to compute, by default the evaluator's `metrics` attribute. + List of metrics to compute, by default the evaluator's `self.metrics` + attribute. plot_kwargs : dict, optional Additional keyword arguments to pass to the plotting methods. @@ -444,16 +504,10 @@ def eval(self, metrics=_sentinel, **plot_kwargs): logprobs_std_nll_ratio, figs, env_metrics] (should be refactored to dict) """ gfn = self.gfn_agent - all_metrics = {} + metrics = self.make_metrics(metrics) + requires = self.make_requires(metrics=metrics) - if metrics is None: - # TODO-V use this in the rest of the code to selectively compute metrics - metrics = set(METRICS.keys()) - requires = set([r for m in metrics for r in METRICS[m]["requires"]]) - - if metrics is _sentinel: - metrics = self.metrics - requires = self.requires + all_metrics = {} if gfn.buffer.test_pkl is None: result = { @@ -473,9 +527,7 @@ def eval(self, metrics=_sentinel, **plot_kwargs): # likelihood of the data according the the GFlowNet policy; and NLL. # TODO: organise code for better efficiency and readability if "log_probs" in requires: - lp_metrics = self.compute_log_prob_metrics( - x_tt, gfn, dict_tt, requires=requires - ) + lp_metrics = self.compute_log_prob_metrics(x_tt, gfn, metrics=metrics) all_metrics.update(lp_metrics) x_sampled = [] @@ -680,7 +732,7 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): return metrics, figs, summary - def eval_and_log(self, it, metrics=_sentinel): + def eval_and_log(self, it, metrics=None): """ Evaluate the GFlowNetAgent and log the results with its logger. From 46597071070bc1faf16bf3ad56e84e78fcbfeedc Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 12:35:57 -0500 Subject: [PATCH 025/106] refactor `requires` --- gflownet/evaluator/base.py | 86 +++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 44 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 1b86c6b14..9d5ccf667 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -22,39 +22,39 @@ METRICS = { "l1": { "name": "L1 error", - "requires": ["density"], + "requirements": ["density"], }, "kl": { "name": "KL Div.", - "requires": ["density"], + "requirements": ["density"], }, "jsd": { "name": "Jensen Shannon Div.", - "requires": ["density"], + "requirements": ["density"], }, "corr_prob_traj_rewards": { "name": "Corr. (test probs., rewards)", - "requires": ["log_probs", "reward_batch"], + "requirements": ["log_probs", "reward_batch"], }, "var_logrewards_logp": { "name": "Var(logR - logp) test", - "requires": ["log_probs", "reward_batch"], + "requirements": ["log_probs", "reward_batch"], }, "nll_tt": { "name": "NLL of test data", - "requires": ["log_probs"], + "requirements": ["log_probs"], }, "mean_logprobs_std": { "name": "Mean BS Std(logp)", - "requires": ["log_probs"], + "requirements": ["log_probs"], }, "mean_probs_std": { "name": "Mean BS Std(p)", - "requires": ["log_probs"], + "requirements": ["log_probs"], }, "logprobs_std_nll_ratio": { "name": "BS Std(logp) / NLL", - "requires": ["log_probs"], + "requirements": ["log_probs"], }, } @@ -86,11 +86,11 @@ def __init__(self, **kwargs): self.gfn_agent = kwargs.get("gfn_agent") self.config = self.gfn_agent.eval_config self.logger = self.gfn_agent.logger - self.requires = set() + self.reqs = set() - self.metrics = self.requires = _sentinel + self.metrics = self.reqs = _sentinel self.metrics = self.make_metrics(self.config.metrics) - self.requires = self.make_require() + self.reqs = self.make_requirements() def make_metrics(self, metrics=None): """ @@ -109,7 +109,7 @@ def make_metrics(self, metrics=None): ------- dict Dictionary of metrics to compute, with the metric names as keys and the - metric names and requires as values. + metric names and requirements as values. Raises ------ @@ -139,24 +139,24 @@ def make_metrics(self, metrics=None): return {m: METRICS[m] for m in metrics} - def make_requires(self, requires=None, metrics=None): + def make_requirements(self, reqs=None, metrics=None): """ Make requirements for the metrics to compute. 1. If `metrics` is provided, they must be as a dict of metrics. The requirements - are computed from the `requires` attribute of the metrics. + are computed from the `requirements` attribute of the metrics. - 2. Otherwise, the requirements are computed from the `requires` argument: - - If `requires` is `"all"`, all requirements of all metrics are computed. - - If `requires` is `None`, the evaluator's `self.requires` attribute is + 2. Otherwise, the requirements are computed from the `reqs` argument: + - If `reqs` is `"all"`, all requirements of all metrics are computed. + - If `reqs` is `None`, the evaluator's `self.reqs` attribute is used. - - If `requires` is a list, it is used as the requirements. + - If `reqs` is a list, it is used as the requirements. Parameters ---------- - requires : Union[str, List[str]], optional + reqs : Union[str, List[str]], optional The metrics requirements. Either `"all"`, a list of requirements or `None` - to use the evaluator's `self.requires` attribute. By default None + to use the evaluator's `self.reqs` attribute. By default None metrics : List[str], optional The list of metrics dicts to compute requirements for. By default None. @@ -167,28 +167,26 @@ def make_requires(self, requires=None, metrics=None): """ if metrics is not None: - return set([r for m in metrics.values() for r in m["requires"]]) - - if requires == "all": - requires = set([r for m in METRICS.values() for r in m["requires"]]) - if requires is None: - if self.requires is _sentinel: - self.requires = set( - [r for m in self.metrics.values() for r in m["requires"]] + return set([r for m in metrics.values() for r in m["requirements"]]) + + if reqs == "all": + reqs = set([r for m in METRICS.values() for r in m["requirements"]]) + if reqs is None: + if self.reqs is _sentinel: + self.reqs = set( + [r for m in self.metrics.values() for r in m["requirements"]] ) - requires = self.requires - if isinstance(requires, list): - requires = set(requires) - - assert isinstance( - requires, set - ), f"requires should be a set, but is {type(requires)}" - assert all([isinstance(r, str) for r in requires]), ( - "All elements of requires should be strings, but are " - + f"{[type(r) for r in requires]}" + reqs = self.reqs + if isinstance(reqs, list): + reqs = set(reqs) + + assert isinstance(reqs, set), f"reqs should be a set, but is {type(reqs)}" + assert all([isinstance(r, str) for r in reqs]), ( + "All elements of reqs should be strings, but are " + + f"{[type(r) for r in reqs]}" ) - return requires + return reqs def do_train(self, step): """ @@ -436,7 +434,7 @@ def plot(self, x_sampled, kde_pred, kde_true, **plot_kwargs): def compute_log_prob_metrics(self, x_tt, gfn, metrics=None): gfn = self.gfn_agent metrics = self.make_metrics(metrics) - requires = self.make_requires(metrics=metrics) + reqs = self.make_requirements(metrics=metrics) logprobs_x_tt, logprobs_std, probs_std = gfn.estimate_logprobs_data( x_tt, @@ -454,7 +452,7 @@ def compute_log_prob_metrics(self, x_tt, gfn, metrics=None): if "mean_probs_std" in metrics: lp_metrics["mean_probs_std"] = probs_std.mean().item() - if "reward_batch" in requires: + if "reward_batch" in reqs: rewards_x_tt = gfn.env.reward_batch(x_tt) if "corr_prob_traj_rewards" in metrics: @@ -505,7 +503,7 @@ def eval(self, metrics=None, **plot_kwargs): """ gfn = self.gfn_agent metrics = self.make_metrics(metrics) - requires = self.make_requires(metrics=metrics) + reqs = self.make_requirements(metrics=metrics) all_metrics = {} @@ -526,7 +524,7 @@ def eval(self, metrics=None, **plot_kwargs): # Compute correlation between the rewards of the test data and the log # likelihood of the data according the the GFlowNet policy; and NLL. # TODO: organise code for better efficiency and readability - if "log_probs" in requires: + if "log_probs" in reqs: lp_metrics = self.compute_log_prob_metrics(x_tt, gfn, metrics=metrics) all_metrics.update(lp_metrics) From 987ab7b60f907abc10128e3658c625a7ee52f629 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 12:58:43 -0500 Subject: [PATCH 026/106] typo -> `should_log_train` --- gflownet/evaluator/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 9d5ccf667..94cd63d28 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -188,7 +188,7 @@ def make_requirements(self, reqs=None, metrics=None): return reqs - def do_train(self, step): + def should_log_train(self, step): """ Check if training logs should be done at the current step. The decision is based on the `self.config.train.period` attribute. From 06ccc614c9067b7c1ab78481f5103a7e759d9ed0 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 12:59:07 -0500 Subject: [PATCH 027/106] `compute_density_metrics` for `eval()` --- gflownet/evaluator/base.py | 170 ++++++++++++++++++++++--------------- 1 file changed, 103 insertions(+), 67 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 94cd63d28..f3d701505 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -479,56 +479,14 @@ def compute_log_prob_metrics(self, x_tt, gfn, metrics=None): return lp_metrics - def eval(self, metrics=None, **plot_kwargs): - """ - Evaluate the GFlowNetAgent and compute metrics and plots. - - If `metrics` is not provided, the evaluator's `self.metrics` attribute is used - (default). - - Parameters - ---------- - metrics : List[str], optional - List of metrics to compute, by default the evaluator's `self.metrics` - attribute. - plot_kwargs : dict, optional - Additional keyword arguments to pass to the plotting methods. - - Returns - ------- - list - List of computed metrics and figures: [l1, kl, jsd, corr_prob_traj_rewards, - var_logrewards_logp, nll_tt, mean_logprobs_std, mean_probs_std, - logprobs_std_nll_ratio, figs, env_metrics] (should be refactored to dict) - """ + def compute_density_metrics(self, x_tt, dict_tt, metrics=None): gfn = self.gfn_agent metrics = self.make_metrics(metrics) - reqs = self.make_requirements(metrics=metrics) - - all_metrics = {} - - if gfn.buffer.test_pkl is None: - result = { - "metrics": { - k: getattr(gfn, k) if hasattr(gfn, k) else None for k in metrics - } - } - result["figs"] = {} - result["env_metrics"] = {} - return result + reqs = self.make_requirements(metrics=metrics) # TODO-V: unused for now, TBD - with open(gfn.buffer.test_pkl, "rb") as f: - dict_tt = pickle.load(f) - x_tt = dict_tt["x"] + density_metrics = {} + x_sampled = density_true = density_pred = None - # Compute correlation between the rewards of the test data and the log - # likelihood of the data according the the GFlowNet policy; and NLL. - # TODO: organise code for better efficiency and readability - if "log_probs" in reqs: - lp_metrics = self.compute_log_prob_metrics(x_tt, gfn, metrics=metrics) - all_metrics.update(lp_metrics) - - x_sampled = [] if gfn.buffer.test_type is not None and gfn.buffer.test_type == "all": batch, _ = gfn.sample_batch(n_forward=self.logger.test.n, train=False) assert batch.is_valid() @@ -550,6 +508,7 @@ def eval(self, metrics=None, **plot_kwargs): density_pred = np.array([hist[tuple(x)] / z_pred for x in x_tt]) log_density_true = np.log(density_true + 1e-8) log_density_pred = np.log(density_pred + 1e-8) + elif gfn.continuous and hasattr(gfn.env, "fit_kde"): batch, _ = gfn.sample_batch(n_forward=self.logger.test.n, train=False) assert batch.is_valid() @@ -590,39 +549,116 @@ def eval(self, metrics=None, **plot_kwargs): log_density_pred = scores_pred - logsumexp(scores_pred, axis=0) density_true = np.exp(log_density_true) density_pred = np.exp(log_density_pred) + else: # TODO: refactor + # TODO-V: remove? / deprecated? env_metrics = gfn.env.test(x_sampled) - return { - "metrics": { - "l1": gfn.l1, - "kl": gfn.kl, - "jsd": gfn.jsd, - **all_metrics, - }, - "figs": {}, - "env_metrics": env_metrics, - } + density_metrics["env_metrics"] = env_metrics + density_metrics["l1"] = gfn.l1 + density_metrics["kl"] = gfn.kl + density_metrics["jsd"] = gfn.jsd + density_metrics["x_sampled"] = x_sampled + density_metrics["kde_pred"] = kde_pred + density_metrics["kde_true"] = kde_true + return density_metrics + # L1 error - l1 = np.abs(density_pred - density_true).mean() + density_metrics["l1"] = np.abs(density_pred - density_true).mean() # KL divergence - kl = (density_true * (log_density_true - log_density_pred)).mean() + density_metrics["kl"] = ( + density_true * (log_density_true - log_density_pred) + ).mean() # Jensen-Shannon divergence log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log(0.5) - jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens)) - jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) + density_metrics["jsd"] = 0.5 * np.sum( + density_true * (log_density_true - log_mean_dens) + ) + density_metrics["jsd"] += 0.5 * np.sum( + density_pred * (log_density_pred - log_mean_dens) + ) + + density_metrics["x_sampled"] = x_sampled + density_metrics["kde_pred"] = kde_pred + density_metrics["kde_true"] = kde_true + + return density_metrics + + def eval(self, metrics=None, **plot_kwargs): + """ + Evaluate the GFlowNetAgent and compute metrics and plots. + + If `metrics` is not provided, the evaluator's `self.metrics` attribute is used + (default). + + Extand in subclasses to add more metrics and plots: + + ```python + def eval(self, metrics=None, **plot_kwargs): + result = super().eval(metrics=metrics, **plot_kwargs) + result["metrics"]["my_custom_metric"] = my_custom_metric_function() + result["figs"]["My custom plot"] = my_custom_plot_function() + return result + ``` + + Parameters + ---------- + metrics : List[str], optional + List of metrics to compute, by default the evaluator's `self.metrics` + attribute. + plot_kwargs : dict, optional + Additional keyword arguments to pass to the plotting methods. + + Returns + ------- + list + List of computed metrics and figures: [l1, kl, jsd, corr_prob_traj_rewards, + var_logrewards_logp, nll_tt, mean_logprobs_std, mean_probs_std, + logprobs_std_nll_ratio, figs, env_metrics] (should be refactored to dict) + """ + gfn = self.gfn_agent + metrics = self.make_metrics(metrics) + reqs = self.make_requirements(metrics=metrics) + + all_metrics = {} + x_sampled = kde_pred = kde_true = None + env_metrics = figs = {} + + if gfn.buffer.test_pkl is None: + result = { + "metrics": { + k: getattr(gfn, k) if hasattr(gfn, k) else None for k in metrics + }, + "figs": figs, + "env_metrics": env_metrics, + } + return result + + with open(gfn.buffer.test_pkl, "rb") as f: + dict_tt = pickle.load(f) + x_tt = dict_tt["x"] + + # Compute correlation between the rewards of the test data and the log + # likelihood of the data according the the GFlowNet policy; and NLL. + # TODO: organise code for better efficiency and readability + if "log_probs" in reqs: + lp_metrics = self.compute_log_prob_metrics(x_tt, gfn, metrics=metrics) + all_metrics.update(lp_metrics) + + if "density" in reqs: + density_metrics = self.compute_density_metrics(x_tt, gfn, metrics=metrics) + x_sampled = density_metrics.pop("x_sampled", x_sampled) + kde_pred = density_metrics.pop("kde_pred", kde_pred) + kde_true = density_metrics.pop("kde_true", kde_true) + env_metrics = density_metrics.pop("env_metrics", env_metrics) + all_metrics.update(density_metrics) figs = self.plot(x_sampled=x_sampled, kde_pred=kde_pred, kde_true=kde_true) return { - "metrics": { - "l1": l1, - "kl": kl, - "jsd": jsd, - **all_metrics, - }, + "metrics": all_metrics, "figs": figs, - "env_metrics": {}, + "env_metrics": env_metrics, } @torch.no_grad() From ee388025675144b735f3f85d1ead5fc15c0d2a43 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:24:08 -0500 Subject: [PATCH 028/106] add `eval:base` default --- config/main.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/main.yaml b/config/main.yaml index dd5d98edf..65aa6cfd1 100644 --- a/config/main.yaml +++ b/config/main.yaml @@ -6,6 +6,7 @@ defaults: - proxy: corners - logger: wandb - user: default + - eval: base # Device device: cuda From 14687550a7a350b13b6ce8dd4f647a3039f56f1d Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:24:18 -0500 Subject: [PATCH 029/106] update configs --- config/eval/base.yaml | 5 +++-- config/logger/base.yaml | 7 ------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/config/eval/base.yaml b/config/eval/base.yaml index 9908967e5..720f65dd4 100644 --- a/config/eval/base.yaml +++ b/config/eval/base.yaml @@ -14,8 +14,9 @@ logprobs_batch_size: 100 logprobs_bootstrap_size: 10000 # Maximum number of test data points to compute log likelihood probs. max_data_logprobs: 1e5 - +train_log_period: 1 +checkpoints_period: 1000 # List of metrics as per gflownet/eval/evaluator.py:METRICS_NAMES # Set to null for all of them # Values must be comma separated like `metrics: "l1, kl, js"` (spaces are optional) -metrics: null \ No newline at end of file +metrics: all \ No newline at end of file diff --git a/config/logger/base.yaml b/config/logger/base.yaml index 8ab28fdfa..0028eb8d9 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -6,19 +6,12 @@ do: project_name: "GFlowNet" -# Train metrics -train: - period: 1 - oracle: period: 100000 k: - 1 - 10 - 100 -# Policy model checkpoints -checkpoints: - period: 1000 # Log dir logdir: From b670c1456d66b5b7356d8f4668d8f4546fcfb2a7 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:24:36 -0500 Subject: [PATCH 030/106] move evaluator init later in gfna init --- gflownet/gflownet.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index d71336100..181f84594 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -130,9 +130,6 @@ def __init__( self.float = set_float_precision(float_precision) # Environment self.env = env - # Evaluator - self.eval_config = eval_config - self.evaluator = GFlowNetEvaluator.from_agent(self) # Continuous environments self.continuous = hasattr(self.env, "continuous") and self.env.continuous if self.continuous and optimizer.loss in ["flowmatch", "flowmatching"]: @@ -240,6 +237,11 @@ def __init__( ) else: self.opt, self.lr_scheduler, self.target = None, None, None + + # Evaluator + self.eval_config = eval_config + self.evaluator = GFlowNetEvaluator.from_agent(self) + self.n_train_steps = optimizer.n_train_steps self.batch_size = optimizer.batch_size self.batch_size_total = sum(self.batch_size.values()) From 06d4d0612aef706a42da27f0c7e1bcc4a19aae71 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:25:02 -0500 Subject: [PATCH 031/106] remove legacy `.test.` references --- gflownet/evaluator/base.py | 52 ++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index f3d701505..39e762f7c 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -206,10 +206,10 @@ def should_log_train(self, step): bool True if training should be done at the current step, False otherwise. """ - if self.config.train.period is None or self.config.train.period < 0: + if self.config.train_log_period is None or self.config.train_log_period < 0: return False else: - return not step % self.config.train.period + return not step % self.config.train_log_period def should_eval(self, step): """ @@ -232,12 +232,12 @@ def should_eval(self, step): bool True if testing should be done at the current step, False otherwise. """ - if self.config.test.period is None or self.config.test.period < 0: + if self.config.period is None or self.config.period < 0: return False - elif step == 1 and self.config.test.first_it: + elif step == 1 and self.config.first_it: return True else: - return not step % self.config.test.period + return not step % self.config.period def should_eval_top_k(self, step): """ @@ -258,13 +258,13 @@ def should_eval_top_k(self, step): bool True if top k plots and metrics should be done at the current step, False """ - if self.config.test.top_k is None or self.config.test.top_k < 0: + if self.config.top_k is None or self.config.top_k < 0: return False - if self.config.test.top_k_period is None or self.config.test.top_k_period < 0: + if self.config.top_k_period is None or self.config.top_k_period < 0: return False - return step == 2 or step % self.config.test.top_k_period == 0 + return step == 2 or step % self.config.top_k_period == 0 def do_oracle(self, step): """ @@ -306,10 +306,10 @@ def should_checkpoint(self, step): bool True if checkpoints should be done at the current step, False otherwise. """ - if self.checkpoints.period is None or self.checkpoints.period < 0: + if self.config.checkpoints_period is None or self.config.checkpoints_period < 0: return False else: - return not step % self.checkpoints.period + return not step % self.config.checkpoints_period @classmethod def from_dir( @@ -431,17 +431,17 @@ def plot(self, x_sampled, kde_pred, kde_true, **plot_kwargs): "Reward KDE": fig_kde_true, } - def compute_log_prob_metrics(self, x_tt, gfn, metrics=None): + def compute_log_prob_metrics(self, x_tt, metrics=None): gfn = self.gfn_agent metrics = self.make_metrics(metrics) reqs = self.make_requirements(metrics=metrics) logprobs_x_tt, logprobs_std, probs_std = gfn.estimate_logprobs_data( x_tt, - n_trajectories=self.logger.test.n_trajs_logprobs, - max_data_size=self.logger.test.max_data_logprobs, - batch_size=self.logger.test.logprobs_batch_size, - bs_num_samples=self.logger.test.logprobs_bootstrap_size, + n_trajectories=self.config.n_trajs_logprobs, + max_data_size=self.config.max_data_logprobs, + batch_size=self.config.logprobs_batch_size, + bs_num_samples=self.config.logprobs_bootstrap_size, ) lp_metrics = {} @@ -488,7 +488,7 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): x_sampled = density_true = density_pred = None if gfn.buffer.test_type is not None and gfn.buffer.test_type == "all": - batch, _ = gfn.sample_batch(n_forward=self.logger.test.n, train=False) + batch, _ = gfn.sample_batch(n_forward=self.config.n, train=False) assert batch.is_valid() x_sampled = batch.get_terminating_states() @@ -510,7 +510,7 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): log_density_pred = np.log(density_pred + 1e-8) elif gfn.continuous and hasattr(gfn.env, "fit_kde"): - batch, _ = gfn.sample_batch(n_forward=self.logger.test.n, train=False) + batch, _ = gfn.sample_batch(n_forward=self.config.n, train=False) assert batch.is_valid() x_sampled = batch.get_terminating_states() # TODO make it work with conditional env @@ -518,21 +518,21 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): x_tt = torch2np(gfn.env.states2proxy(x_tt)) kde_pred = gfn.env.fit_kde( x_sampled, - kernel=self.logger.test.kde.kernel, - bandwidth=self.logger.test.kde.bandwidth, + kernel=self.config.kde.kernel, + bandwidth=self.config.kde.bandwidth, ) if "log_density_true" in dict_tt and "kde_true" in dict_tt: log_density_true = dict_tt["log_density_true"] kde_true = dict_tt["kde_true"] else: # Sample from reward via rejection sampling - x_from_reward = gfn.env.sample_from_reward(n_samples=self.logger.test.n) + x_from_reward = gfn.env.sample_from_reward(n_samples=self.config.n) x_from_reward = torch2np(gfn.env.states2proxy(x_from_reward)) # Fit KDE with samples from reward kde_true = gfn.env.fit_kde( x_from_reward, - kernel=self.logger.test.kde.kernel, - bandwidth=self.logger.test.kde.bandwidth, + kernel=self.config.kde.kernel, + bandwidth=self.config.kde.bandwidth, ) # Estimate true log density using test samples # TODO: this may be specific-ish for the torus or not @@ -642,11 +642,13 @@ def eval(self, metrics=None, **plot_kwargs): # likelihood of the data according the the GFlowNet policy; and NLL. # TODO: organise code for better efficiency and readability if "log_probs" in reqs: - lp_metrics = self.compute_log_prob_metrics(x_tt, gfn, metrics=metrics) + lp_metrics = self.compute_log_prob_metrics(x_tt, metrics=metrics) all_metrics.update(lp_metrics) if "density" in reqs: - density_metrics = self.compute_density_metrics(x_tt, gfn, metrics=metrics) + density_metrics = self.compute_density_metrics( + x_tt, dict_tt, metrics=metrics + ) x_sampled = density_metrics.pop("x_sampled", x_sampled) kde_pred = density_metrics.pop("kde_pred", kde_pred) kde_true = density_metrics.pop("kde_true", kde_true) @@ -786,7 +788,7 @@ def eval_and_log(self, it, metrics=None): for m, v in result["metrics"].items(): setattr(gfn, m, v) - mertics_to_log = {METRICS[k]["name"]: v for k, v in result["metrics"].values()} + mertics_to_log = {METRICS[k]["name"]: v for k, v in result["metrics"].items()} self.logger.log_metrics(mertics_to_log, it, gfn.use_context) self.logger.log_metrics(result["env_metrics"], it, use_context=gfn.use_context) From 58092542edb101f380816436391547660045e328 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:25:28 -0500 Subject: [PATCH 032/106] debug print --- gflownet/utils/logger.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 9c41dd037..1525e1de7 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -324,6 +324,9 @@ def save_models( path = self.sf_ckpt_path.parent / stem torch.save(state_flow.model.state_dict(), path) + if self.debug: + print(f"Models saved at step {step} in {path}") + def log_time(self, times: dict, use_context: bool): if self.do.times: times = {"time_{}".format(k): v for k, v in times.items()} From 4a3b049ba58db63d0ede8d8d96c6339b7ddd46f3 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:26:51 -0500 Subject: [PATCH 033/106] fix logdir exists logic and `exit(1)` --- gflownet/utils/logger.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 1525e1de7..6e9fcf5ec 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -1,4 +1,5 @@ import os +import sys from datetime import datetime from pathlib import Path from typing import Union @@ -24,7 +25,6 @@ def __init__( project_name: str, logdir: dict, oracle: dict, - checkpoints: dict, progress: bool, lightweight: bool, debug: bool, @@ -37,7 +37,6 @@ def __init__( self.do = do self.do.times = self.do.times and self.do.online self.oracle = oracle - self.checkpoints = checkpoints slurm_job_id = os.environ.get("SLURM_JOB_ID") if run_name is None: @@ -70,11 +69,11 @@ def __init__( self.debug = debug # Log directory self.logdir = Path(logdir.root) - if self.logdir.exists() or logdir.overwrite: + if not self.logdir.exists() or logdir.overwrite: self.logdir.mkdir(parents=True, exist_ok=True) else: - # TODO: this message seems contradictory with the logic print(f"logdir {logdir} already exists! - Ending run...") + sys.exit(1) self.ckpts_dir = self.logdir / logdir.ckpts self.ckpts_dir.mkdir(parents=True, exist_ok=True) # Write wandb URL From 0cd21e0cdc47a5f111876b7c4cf5e9bb7088f2bb Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:57:02 -0500 Subject: [PATCH 034/106] trailing whitespace --- config/logger/wandb.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/logger/wandb.yaml b/config/logger/wandb.yaml index e04be706e..4614e99e9 100644 --- a/config/logger/wandb.yaml +++ b/config/logger/wandb.yaml @@ -3,5 +3,5 @@ defaults: _target_: gflownet.utils.logger.Logger -tags: +tags: - gflownet From 4802ab203f4b4c1384d18063cc4c7d30c93d8baa Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 19 Feb 2024 14:57:18 -0500 Subject: [PATCH 035/106] remove `oracle` references --- config/gflownet/gflownet.yaml | 3 --- config/logger/base.yaml | 7 ------- gflownet/evaluator/base.py | 35 ++--------------------------------- gflownet/gflownet.py | 4 ---- gflownet/utils/logger.py | 28 +++++++++++----------------- 5 files changed, 13 insertions(+), 64 deletions(-) diff --git a/config/gflownet/gflownet.yaml b/config/gflownet/gflownet.yaml index 388d3dfbe..d0acec1b1 100644 --- a/config/gflownet/gflownet.yaml +++ b/config/gflownet/gflownet.yaml @@ -52,8 +52,5 @@ replay_sampling: permutation # Train data set backward sampling train_sampling: permutation num_empirical_loss: 200000 -oracle: - # Number of samples for oracle metrics - n: 500 sample_only: False active_learning: False diff --git a/config/logger/base.yaml b/config/logger/base.yaml index 0028eb8d9..a16c9347f 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -6,13 +6,6 @@ do: project_name: "GFlowNet" -oracle: - period: 100000 - k: - - 1 - - 10 - - 100 - # Log dir logdir: root: ./logs diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 39e762f7c..041bfc56f 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -266,28 +266,6 @@ def should_eval_top_k(self, step): return step == 2 or step % self.config.top_k_period == 0 - def do_oracle(self, step): - """ - Check if oracle should be done at the current step. The decision is based on the - `self.config.oracle.period` attribute. - - Set `self.config.oracle.period` to `None` or a negative value to disable oracle. - - Parameters - ---------- - step : int - Current iteration step. - - Returns - ------- - bool - True if oracle should be done at the current step, False otherwise. - """ - if self.config.oracle.period is None or self.config.oracle.period < 0: - return False - else: - return not step % self.oracle.period - def should_checkpoint(self, step): """ Check if checkpoints should be done at the current step. The decision is based @@ -551,10 +529,6 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): density_pred = np.exp(log_density_pred) else: - # TODO: refactor - # TODO-V: remove? / deprecated? - env_metrics = gfn.env.test(x_sampled) - density_metrics["env_metrics"] = env_metrics density_metrics["l1"] = gfn.l1 density_metrics["kl"] = gfn.kl density_metrics["jsd"] = gfn.jsd @@ -614,7 +588,7 @@ def eval(self, metrics=None, **plot_kwargs): list List of computed metrics and figures: [l1, kl, jsd, corr_prob_traj_rewards, var_logrewards_logp, nll_tt, mean_logprobs_std, mean_probs_std, - logprobs_std_nll_ratio, figs, env_metrics] (should be refactored to dict) + logprobs_std_nll_ratio, figs] (should be refactored to dict) #TODO fix docstring """ gfn = self.gfn_agent metrics = self.make_metrics(metrics) @@ -622,7 +596,7 @@ def eval(self, metrics=None, **plot_kwargs): all_metrics = {} x_sampled = kde_pred = kde_true = None - env_metrics = figs = {} + figs = {} if gfn.buffer.test_pkl is None: result = { @@ -630,7 +604,6 @@ def eval(self, metrics=None, **plot_kwargs): k: getattr(gfn, k) if hasattr(gfn, k) else None for k in metrics }, "figs": figs, - "env_metrics": env_metrics, } return result @@ -652,7 +625,6 @@ def eval(self, metrics=None, **plot_kwargs): x_sampled = density_metrics.pop("x_sampled", x_sampled) kde_pred = density_metrics.pop("kde_pred", kde_pred) kde_true = density_metrics.pop("kde_true", kde_true) - env_metrics = density_metrics.pop("env_metrics", env_metrics) all_metrics.update(density_metrics) figs = self.plot(x_sampled=x_sampled, kde_pred=kde_pred, kde_true=kde_true) @@ -660,7 +632,6 @@ def eval(self, metrics=None, **plot_kwargs): return { "metrics": all_metrics, "figs": figs, - "env_metrics": env_metrics, } @torch.no_grad() @@ -783,7 +754,6 @@ def eval_and_log(self, it, metrics=None): List of metrics to compute, by default the evaluator's `metrics` attribute. """ gfn = self.gfn_agent - # TODO-V: do we need to set attributes? result = self.eval(metrics=metrics) for m, v in result["metrics"].items(): setattr(gfn, m, v) @@ -791,7 +761,6 @@ def eval_and_log(self, it, metrics=None): mertics_to_log = {METRICS[k]["name"]: v for k, v in result["metrics"].items()} self.logger.log_metrics(mertics_to_log, it, gfn.use_context) - self.logger.log_metrics(result["env_metrics"], it, use_context=gfn.use_context) self.logger.log_plots(result["figs"], it, use_context=gfn.use_context) def eval_and_log_top_k(self, it): diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 181f84594..6cb934211 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -47,7 +47,6 @@ def __init__( pct_offline, logger, num_empirical_loss, - oracle, eval_config, state_flow=None, active_learning=False, @@ -96,8 +95,6 @@ def __init__( (`gflownet/utils/logger.py:Logger`). num_empirical_loss : int Number of empirical loss samples to be used for training. - oracle : dict - Oracle config dictionary. See gflownet.yaml:oracle for details. eval_config : dict, optional Evaluator config dictionary. See `eval/base.yaml` for details. By default None. @@ -161,7 +158,6 @@ def __init__( # Logging self.num_empirical_loss = num_empirical_loss self.logger = logger - self.oracle_n = oracle.n # Buffers self.replay_sampling = replay_sampling self.train_sampling = train_sampling diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 6e9fcf5ec..ac8bd92fd 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -24,7 +24,6 @@ def __init__( do: dict, project_name: str, logdir: dict, - oracle: dict, progress: bool, lightweight: bool, debug: bool, @@ -36,7 +35,6 @@ def __init__( self.config = config self.do = do self.do.times = self.do.times and self.do.online - self.oracle = oracle slurm_job_id = os.environ.get("SLURM_JOB_ID") if run_name is None: @@ -267,18 +265,6 @@ def log_sampler_test( use_context=use_context, ) - def log_sampler_oracle(self, energies: array, step: int, use_context: bool): - # TODO-V -> remove? Unused - if not self.do.online: - return - if step.do_oracle(step): - energies_sorted = np.sort(energies) - dict_topk = {} - for k in self.oracle.k: - mean_topk = np.mean(energies_sorted[:k]) - dict_topk.update({"oracle_mean_top{}".format(k): mean_topk}) - self.log_metrics(dict_topk, use_context=use_context) - def log_losses( self, losses: list, @@ -303,12 +289,19 @@ def save_models( ): if final: ckpt_id = "final" + if self.debug: + print(f"Saving final models in {self.ckpts_dir}") else: ckpt_id = "_iter{:06d}".format(step) + if self.debug: + print(f"Saving models at step {step} in {self.ckpts_dir}") + if forward_policy.is_model and self.pf_ckpt_path is not None: stem = self.pf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" path = self.pf_ckpt_path.parent / stem torch.save(forward_policy.model.state_dict(), path) + if self.debug: + print(f"Forward policy saved in {path}") if ( backward_policy and backward_policy.is_model @@ -317,14 +310,15 @@ def save_models( stem = self.pb_ckpt_path.stem + self.context + ckpt_id + ".ckpt" path = self.pb_ckpt_path.parent / stem torch.save(backward_policy.model.state_dict(), path) + if self.debug: + print(f"Backward policy saved in {path}") if state_flow is not None and self.sf_ckpt_path is not None: stem = self.sf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" path = self.sf_ckpt_path.parent / stem torch.save(state_flow.model.state_dict(), path) - - if self.debug: - print(f"Models saved at step {step} in {path}") + if self.debug: + print(f"State flow saved in {path}") def log_time(self, times: dict, use_context: bool): if self.do.times: From d8c9a7b69b6b57b62ead6b97d8ad79b3bc3d2bab Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:40:01 -0500 Subject: [PATCH 036/106] add `eval` default --- config/tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/tests.yaml b/config/tests.yaml index 7f25a2ad6..d707c7f36 100644 --- a/config/tests.yaml +++ b/config/tests.yaml @@ -6,6 +6,7 @@ defaults: - policy: mlp - logger: base - user: alex + - eval: base # Device device: cpu From 272310c9c11f8fe90d08666d21e855bf1c28a807 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:40:45 -0500 Subject: [PATCH 037/106] `_self_` last to allow for overrides in `_self_` to other name spaces --- config/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/tests.yaml b/config/tests.yaml index d707c7f36..ba3346aea 100644 --- a/config/tests.yaml +++ b/config/tests.yaml @@ -1,5 +1,4 @@ defaults: - - _self_ - env: grid - gflownet: trajectorybalance - proxy: uniform @@ -7,6 +6,7 @@ defaults: - logger: base - user: alex - eval: base + - _self_ # Device device: cpu From 51607597c08f908af20fa266746ffa8be5c6b652 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:41:23 -0500 Subject: [PATCH 038/106] `name` -> `display_name` --- gflownet/evaluator/base.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 041bfc56f..f587d606a 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -21,39 +21,39 @@ METRICS = { "l1": { - "name": "L1 error", + "display_name": "L1 error", "requirements": ["density"], }, "kl": { - "name": "KL Div.", + "display_name": "KL Div.", "requirements": ["density"], }, "jsd": { - "name": "Jensen Shannon Div.", + "display_name": "Jensen Shannon Div.", "requirements": ["density"], }, "corr_prob_traj_rewards": { - "name": "Corr. (test probs., rewards)", + "display_name": "Corr. (test probs., rewards)", "requirements": ["log_probs", "reward_batch"], }, "var_logrewards_logp": { - "name": "Var(logR - logp) test", + "display_name": "Var(logR - logp) test", "requirements": ["log_probs", "reward_batch"], }, "nll_tt": { - "name": "NLL of test data", + "display_name": "NLL of test data", "requirements": ["log_probs"], }, "mean_logprobs_std": { - "name": "Mean BS Std(logp)", + "display_name": "Mean BS Std(logp)", "requirements": ["log_probs"], }, "mean_probs_std": { - "name": "Mean BS Std(p)", + "display_name": "Mean BS Std(p)", "requirements": ["log_probs"], }, "logprobs_std_nll_ratio": { - "name": "BS Std(logp) / NLL", + "display_name": "BS Std(logp) / NLL", "requirements": ["log_probs"], }, } @@ -758,7 +758,9 @@ def eval_and_log(self, it, metrics=None): for m, v in result["metrics"].items(): setattr(gfn, m, v) - mertics_to_log = {METRICS[k]["name"]: v for k, v in result["metrics"].items()} + mertics_to_log = { + METRICS[k]["display_name"]: v for k, v in result["metrics"].items() + } self.logger.log_metrics(mertics_to_log, it, gfn.use_context) self.logger.log_plots(result["figs"], it, use_context=gfn.use_context) From 975024e88a96736d97922a30b6f0a249842752d8 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:42:11 -0500 Subject: [PATCH 039/106] `ALL_REQS` and `ValueError`s --- gflownet/evaluator/base.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index f587d606a..437816039 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -57,6 +57,7 @@ "requirements": ["log_probs"], }, } +ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) class GFlowNetEvaluator: @@ -169,8 +170,14 @@ def make_requirements(self, reqs=None, metrics=None): if metrics is not None: return set([r for m in metrics.values() for r in m["requirements"]]) - if reqs == "all": - reqs = set([r for m in METRICS.values() for r in m["requirements"]]) + if isinstance(reqs, str): + if reqs == "all": + reqs = ALL_REQS.copy() + else: + raise ValueError( + "reqs should be 'all', a list of requirements or None, but is " + + f"{reqs}." + ) if reqs is None: if self.reqs is _sentinel: self.reqs = set( @@ -186,6 +193,10 @@ def make_requirements(self, reqs=None, metrics=None): + f"{[type(r) for r in reqs]}" ) + for r in reqs: + if r not in ALL_REQS: + raise ValueError(f"Unknown requirement: {r}") + return reqs def should_log_train(self, step): From a532eab577095edc8413b012dcb722a57f6a9689 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:42:34 -0500 Subject: [PATCH 040/106] missing tensor `.item()` --- gflownet/evaluator/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 437816039..b924b0c22 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -464,7 +464,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): if "logprobs_std_nll_ratio" in metrics: lp_metrics["logprobs_std_nll_ratio"] = ( -logprobs_std.mean() / logprobs_x_tt.mean() - ) + ).item() return lp_metrics From b3b7ff2543effc07775c030faaafd70d45a90d14 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:43:45 -0500 Subject: [PATCH 041/106] move `kde_pred` to continuous density metrics only --- gflownet/evaluator/base.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index b924b0c22..c4f2a8fce 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -539,13 +539,14 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): density_true = np.exp(log_density_true) density_pred = np.exp(log_density_pred) + density_metrics["kde_pred"] = kde_pred + density_metrics["kde_true"] = kde_true + else: density_metrics["l1"] = gfn.l1 density_metrics["kl"] = gfn.kl density_metrics["jsd"] = gfn.jsd density_metrics["x_sampled"] = x_sampled - density_metrics["kde_pred"] = kde_pred - density_metrics["kde_true"] = kde_true return density_metrics # L1 error @@ -564,9 +565,6 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): ) density_metrics["x_sampled"] = x_sampled - density_metrics["kde_pred"] = kde_pred - density_metrics["kde_true"] = kde_true - return density_metrics def eval(self, metrics=None, **plot_kwargs): From 610dcfcf31184a972411f598edd475f38f703407 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:44:06 -0500 Subject: [PATCH 042/106] store pkl & csv paths as `Buffer` attributes --- gflownet/utils/buffer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index 75edc4d44..fbdd6252d 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -43,6 +43,12 @@ def __init__( self.replay_trajs = {} self.replay_rewards = {} self.replay_pkl = "replay.pkl" + + self.train_csv = None + self.train_pkl = None + self.test_csv = None + self.test_pkl = None + self.save_replay() # Define train and test data sets if train is not None and "type" in train: @@ -56,6 +62,7 @@ def __init__( and train.output_csv is not None ): self.train.to_csv(train.output_csv) + self.train_csv = train.output_csv if ( dict_tr is not None and "output_pkl" in train @@ -74,6 +81,7 @@ def __init__( """ ) self.train_pkl = None + if test is not None and "type" in test: self.test_type = test.type else: @@ -84,6 +92,7 @@ def __init__( and "output_csv" in test and test.output_csv is not None ): + self.test_csv = test.output_csv self.test.to_csv(test.output_csv) if dict_tt is not None and "output_pkl" in test and test.output_pkl is not None: with open(test.output_pkl, "wb") as f: From c4294f6eee9af06883ee5c1e09696c6cadb11955 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:44:50 -0500 Subject: [PATCH 043/106] Imrpove robustness and allow `dict` metrics to `make_metrics` --- gflownet/evaluator/base.py | 88 ++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 23 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index c4f2a8fce..c2bd681b7 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -95,14 +95,20 @@ def __init__(self, **kwargs): def make_metrics(self, metrics=None): """ - Parse metrics from a list, a string or None. + Parse metrics from a dict, list, a string or None. - If `None`, all metrics are computed. If a string, it can be a comma-separated - list of metric names, with or without spaces. All metrics must be in `METRICS`. + - If `None`, all metrics are selected. + - If a string, it can be a comma-separated list of metric names, with or without + spaces. + - If a list, it should be a list of metric names (keys of `METRICS`). + - If a dict, its keys should be metric names and its values will be ignored: + they will be assigned from `METRICS`. + + All metrics must be in `METRICS`. Parameters ---------- - metrics : (Union[str, List[str]], optional) + metrics : Union[str, List[str]], optional Metrics to compute when running the `evaluator.eval()` function. Defaults to None, i.e. all metrics in `METRICS` are computed. @@ -110,16 +116,13 @@ def make_metrics(self, metrics=None): ------- dict Dictionary of metrics to compute, with the metric names as keys and the - metric names and requirements as values. + metric display names and requirements as values. Raises ------ ValueError If a metric name is not in `METRICS`. """ - if metrics == "all": - metrics = METRICS.keys() - if metrics is None: assert self.metrics is not _sentinel, ( "Error setting self.metrics. This is likely due to the `metrics:`" @@ -128,12 +131,31 @@ def make_metrics(self, metrics=None): ) return self.metrics + if not isinstance(metrics, (str, list, dict)): + raise ValueError( + "metrics should be None, a string, a list or a dict," + + f" but is {type(metrics)}." + ) + + if metrics == "all": + metrics = METRICS.keys() + if isinstance(metrics, str): + if metrics == "": + raise ValueError( + "`metrics` should not be an empty string. " + + "Set to 'all' or a list of metric names or None (null in YAML)." + ) if "," in metrics: - metrics = [m.strip() for m in metrics.split(",")] + metrics = metrics.split(",") else: metrics = [metrics] + if isinstance(metrics, dict): + metrics = metrics.keys() + + metrics = [m.strip() for m in metrics] + for m in metrics: if m not in METRICS: raise ValueError(f"Unknown metric name: {m}") @@ -149,8 +171,7 @@ def make_requirements(self, reqs=None, metrics=None): 2. Otherwise, the requirements are computed from the `reqs` argument: - If `reqs` is `"all"`, all requirements of all metrics are computed. - - If `reqs` is `None`, the evaluator's `self.reqs` attribute is - used. + - If `reqs` is `None`, the evaluator's `self.reqs` attribute is used. - If `reqs` is a list, it is used as the requirements. Parameters @@ -158,8 +179,9 @@ def make_requirements(self, reqs=None, metrics=None): reqs : Union[str, List[str]], optional The metrics requirements. Either `"all"`, a list of requirements or `None` to use the evaluator's `self.reqs` attribute. By default None - metrics : List[str], optional - The list of metrics dicts to compute requirements for. By default None. + metrics : Union[str, List[str], dict], optional + The metrics to compute requirements for. If not a dict, will be passed to + `make_metrics`. By default None. Returns ------- @@ -168,6 +190,11 @@ def make_requirements(self, reqs=None, metrics=None): """ if metrics is not None: + if not isinstance(metrics, dict): + metrics = self.make_metrics(metrics) + for m in metrics: + if m not in METRICS: + raise ValueError(f"Unknown metric name: {m}") return set([r for m in metrics.values() for r in m["requirements"]]) if isinstance(reqs, str): @@ -180,6 +207,12 @@ def make_requirements(self, reqs=None, metrics=None): ) if reqs is None: if self.reqs is _sentinel: + if not isinstance(self.metrics, dict): + raise ValueError( + "Cannot compute requirements from `None` without the `metrics`" + + " argument or the `self.metrics` attribute set to a dict" + + " of metrics." + ) self.reqs = set( [r for m in self.metrics.values() for r in m["requirements"]] ) @@ -187,7 +220,10 @@ def make_requirements(self, reqs=None, metrics=None): if isinstance(reqs, list): reqs = set(reqs) - assert isinstance(reqs, set), f"reqs should be a set, but is {type(reqs)}" + assert isinstance( + reqs, set + ), f"reqs should be None, 'all', a set or a list, but is {type(reqs)}" + assert all([isinstance(r, str) for r in reqs]), ( "All elements of reqs should be strings, but are " + f"{[type(r) for r in reqs]}" @@ -215,12 +251,12 @@ def should_log_train(self, step): Returns ------- bool - True if training should be done at the current step, False otherwise. + True if train logging should be done at the current step, False otherwise. """ - if self.config.train_log_period is None or self.config.train_log_period < 0: + if self.config.train_log_period is None or self.config.train_log_period <= 0: return False else: - return not step % self.config.train_log_period + return step % self.config.train_log_period == 0 def should_eval(self, step): """ @@ -243,12 +279,12 @@ def should_eval(self, step): bool True if testing should be done at the current step, False otherwise. """ - if self.config.period is None or self.config.period < 0: + if self.config.period is None or self.config.period <= 0: return False elif step == 1 and self.config.first_it: return True else: - return not step % self.config.period + return step % self.config.period == 0 def should_eval_top_k(self, step): """ @@ -269,13 +305,16 @@ def should_eval_top_k(self, step): bool True if top k plots and metrics should be done at the current step, False """ - if self.config.top_k is None or self.config.top_k < 0: + if self.config.top_k is None or self.config.top_k <= 0: return False - if self.config.top_k_period is None or self.config.top_k_period < 0: + if self.config.top_k_period is None or self.config.top_k_period <= 0: return False - return step == 2 or step % self.config.top_k_period == 0 + if step == 1 and self.config.first_it: + return True + + return step % self.config.top_k_period == 0 def should_checkpoint(self, step): """ @@ -295,7 +334,10 @@ def should_checkpoint(self, step): bool True if checkpoints should be done at the current step, False otherwise. """ - if self.config.checkpoints_period is None or self.config.checkpoints_period < 0: + if ( + self.config.checkpoints_period is None + or self.config.checkpoints_period <= 0 + ): return False else: return not step % self.config.checkpoints_period From 0af5719527ff5e44ebb01bf370df11b776787504 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:45:06 -0500 Subject: [PATCH 044/106] utils for tests file --- tests/utils_for_tests.py | 94 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/utils_for_tests.py diff --git a/tests/utils_for_tests.py b/tests/utils_for_tests.py new file mode 100644 index 000000000..08d34ea32 --- /dev/null +++ b/tests/utils_for_tests.py @@ -0,0 +1,94 @@ +import contextlib +import os +import tempfile +from pathlib import Path + +from hydra import compose, initialize + + +def find_root(): + """ + Find the root of the repository by looking for the .git folder and the gflownet + folder in the parent directories. + + Returns + ------- + Path + The root of the repository as a pathlib.Path object. + + Raises + ------ + RuntimeError + If the root of the repository could not be found. + """ + path = Path(__file__).resolve() + while not ( + (path / ".git").exists() + and (path / "gflownet").exists() + and (path / "config").exists() + and path != path.parent + ): + path = path.parent + if path == path.parent: + raise RuntimeError("Could not find root of the repository") + return path + + +REPO_ROOT = find_root() + + +def load_base_test_config(overrides=[]): + """ + Load the base test configuration in config/tests.yaml. + Simulate command-line args with overrides. + + Examples + -------- + >>> load_base_test_config(["env=grid", "env.buffer.test=None"]) + + Parameters + ---------- + overrides : list[str], optional + A list of overrides for the configuration, by default []. + + Returns + ------- + OmegaConf + The configuration as an OmegaConf object. + """ + with initialize( + version_base="1.1", + config_path=os.path.relpath( + str(REPO_ROOT / "config"), start=str(Path(__file__).parent) + ), + job_name="xxx", + ): + config = compose(config_name="tests", overrides=overrides) + return config + + +@contextlib.contextmanager +def ch_tmpdir(disable=False): + """ + Change to a temporary directory and change back to the original directory when the + context manager exits. + + Parameters + ---------- + disable : bool, optional + Whether to disable the context manager, by default False. + + Yields + ------ + str + The path of the temporary directory (if not disabled) or the original directory. + """ + d = os.getcwd() + with tempfile.TemporaryDirectory() as tmpdirname: + if not disable: + os.chdir(tmpdirname) + try: + yield tmpdirname if not disable else d + + finally: + os.chdir(d) From d51b4b6aa5147d39f389b7ef2392ba9b9132f1e5 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:45:55 -0500 Subject: [PATCH 045/106] + `gflownet_from_config` --- gflownet/utils/common.py | 58 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 8dccb6698..80825a89e 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -223,6 +223,64 @@ def find_latest_checkpoint(ckpt_dir, ckpt_name): return sorted(ckpts, key=lambda f: float(f.stem.split("iter")[1]))[-1] +def gflownet_from_config(config): + """ + Create GFlowNet from a Hydra OmegaConf config. + + Parameters + ---------- + config : DictConfig + Config. + + Returns + ------- + GFN + GFlowNet. + """ + # Logger + logger = instantiate(config.logger, config, _recursive_=False) + # The proxy is required in the env for scoring: might be an oracle or a model + proxy = instantiate( + config.proxy, + device=config.device, + float_precision=config.float_precision, + ) + # The proxy is passed to env and used for computing rewards + env = instantiate( + config.env, + proxy=proxy, + device=config.device, + float_precision=config.float_precision, + ) + forward_config = parse_policy_config(config, kind="forward") + backward_config = parse_policy_config(config, kind="backward") + forward_policy = instantiate( + forward_config, + env=env, + device=config.device, + float_precision=config.float_precision, + ) + backward_policy = instantiate( + backward_config, + env=env, + device=config.device, + float_precision=config.float_precision, + base=forward_policy, + ) + gflownet = instantiate( + config.gflownet, + device=config.device, + float_precision=config.float_precision, + env=env, + buffer=config.env.buffer, + forward_policy=forward_policy, + backward_policy=backward_policy, + logger=logger, + eval_config=config.eval, + ) + return gflownet + + def load_gflow_net_from_run_path( run_path, no_wandb=True, From 25db913999fd830130523db4f9f86324d54f2cd8 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:46:17 -0500 Subject: [PATCH 046/106] generic fixtures --- tests/conftest.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..60975c006 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +import sys + +import pytest +from utils_for_tests import REPO_ROOT, ch_tmpdir, load_base_test_config + +sys.path.append(str(REPO_ROOT)) + +from gflownet.utils.common import gflownet_from_config + + +@pytest.fixture +def config_for_tests(request): + """ + Load the base test configuration in config/tests.yaml. + + Simulate Hydra command-line args with overrides. + + Examples + -------- + ```python + @pytest.mark.parametrize( + "config_for_tests,something_else", + [ + (["env=grid", "env.buffer.test=None"], 1), + (["env=ctorus"], 2), + ], + indirect=["config_for_tests"], + ) + def test_my_func(config_for_tests, something_else): + ... + ``` + + Parameters + ---------- + request : FixtureRequest + The request object from pytest. Its `param` attribute is used to pass the + argument to the fixture. + + Returns + ------- + OmegaConf + The configuration as an OmegaConf object with the Hydra overrides applied. + """ + overrides = [] + if hasattr(request, "param") and request.param is not None: + overrides = request.param + if not isinstance(overrides, list): + overrides = [overrides] + assert isinstance(overrides, list), "Overrides must be a list." + assert all( + isinstance(ov, str) for ov in overrides + ), "Overrides must be a list of string." + + config = load_base_test_config(overrides=overrides) + return config + + +@pytest.fixture +def gflownet_for_tests(config_for_tests): + """ + Create a GFlowNet object from the configuration for tests. + + This is a generator so the code after `yield` is executed at the end of the test + which uses this fixture (akin to a `finally` block in a `try` statement or a + unittest `tearDown` method). + + By default, the execution is moved to a temporary directory to avoid polluting the + current directory with files written by the GFlowNetAgent. + + Set the `disable` parameter to `True` to avoid moving the execution to a temporary + directory (for example, when developing tests and wanting to inspect the files). + + Parameters + ---------- + config_for_tests : OmegaConf + The configuration for the GFlowNetAgent to be created. + + Yields + ------ + GFlowNetAgent + The loaded GFlowNetAgent object. + """ + # Move execution to a temporary directory if disable is not True + with ch_tmpdir(disable=False) as tmpdir: + print(f"Current GFlowNetAgent execution directory: {tmpdir}") + gfn = gflownet_from_config(config_for_tests) + yield gfn + + # Any teardown (=post-test) code goes here + pass From acc56f85f06325a58c6d7a45a6e46b89e52865c9 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:46:55 -0500 Subject: [PATCH 047/106] first tests for `gflownet.eval.base.GFlowNetEvaluator` --- tests/gflownet/eval/test_base.py | 252 +++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 tests/gflownet/eval/test_base.py diff --git a/tests/gflownet/eval/test_base.py b/tests/gflownet/eval/test_base.py new file mode 100644 index 000000000..a62c3162d --- /dev/null +++ b/tests/gflownet/eval/test_base.py @@ -0,0 +1,252 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +from omegaconf import OmegaConf + +from gflownet.evaluator.base import METRICS, GFlowNetEvaluator, _sentinel + +PERIOD_STEP_TARGET = [ + (0, 0, False), + (0, 1, False), + (0, 2, False), + (-1, 0, False), + (-1, 1, False), + (-1, 2, False), + (None, 0, False), + (None, 1, False), + (None, 2, False), + (1, 0, True), + (1, 1, True), + (1, 2, True), + (2, 0, True), + (2, 1, False), + (2, 2, True), + (3, 0, True), + (3, 1, False), + (3, 2, False), + (3, 3, True), +] + +CONSTANT_EVALUATOR = GFlowNetEvaluator( + gfn_agent=OmegaConf.create({"eval_config": {"metrics": "all"}, "logger": {}}), + sentinel=_sentinel, +) + + +@pytest.fixture +def dummy_evaluator(config_for_tests): + gfna_dummy = OmegaConf.create( + { + "eval_config": config_for_tests.eval, + "logger": config_for_tests.logger, + } + ) + return GFlowNetEvaluator(gfn_agent=gfna_dummy, sentinel=_sentinel) + + +@pytest.fixture +def constant_evaluator(): + CONSTANT_EVALUATOR.config = OmegaConf.create( + {"eval_config": {"metrics": "all"}, "logger": {}} + ) + return CONSTANT_EVALUATOR + + +@pytest.fixture +def all_reqs(): + return set([r for m in METRICS.values() for r in m["requirements"]]) + + +def test__make_metrics__all(dummy_evaluator): + assert dummy_evaluator.make_metrics("all") == METRICS + + +def test__make_metrics__str(dummy_evaluator): + assert dummy_evaluator.make_metrics("l1,kl") == { + k: METRICS[k] for k in ["l1", "kl"] + } + assert dummy_evaluator.make_metrics(" l1, kl") == { + k: METRICS[k] for k in ["l1", "kl"] + } + with pytest.raises(ValueError, match="Unknown metric name.*"): + dummy_evaluator.make_metrics("invalid") + with pytest.raises(ValueError, match="Unknown metric name.*"): + dummy_evaluator.make_metrics("l1,kj") + + with pytest.raises(ValueError, match=".*should not be an empty string.*"): + dummy_evaluator.make_metrics("") + + +def test__make_metrics__list(dummy_evaluator): + assert dummy_evaluator.make_metrics(["l1", "kl"]) == { + k: METRICS[k] for k in ["l1", "kl"] + } + assert dummy_evaluator.make_metrics([" l1", " kl"]) == { + k: METRICS[k] for k in ["l1", "kl"] + } + with pytest.raises(ValueError, match="Unknown metric name.*"): + dummy_evaluator.make_metrics(["invalid"]) + + +def test__make_metrics__None(dummy_evaluator): + assert dummy_evaluator.make_metrics() is dummy_evaluator.metrics + + with pytest.raises(AssertionError): + dummy_evaluator.metrics = _sentinel + dummy_evaluator.make_metrics() + + +def test__make_metrics__dict(dummy_evaluator): + assert dummy_evaluator.make_metrics({"l1": METRICS["l1"], "kl": METRICS["kl"]}) == { + k: METRICS[k] for k in ["l1", "kl"] + } + with pytest.raises(ValueError, match="Unknown metric name.*"): + dummy_evaluator.make_metrics( + { + "l1": METRICS["l1"], + "invalid": {"name": "invalid", "requirements": ["anything"]}, + } + ) + + +def test__make_metrics__other(dummy_evaluator): + with pytest.raises(ValueError, match="metrics should be None, a string.*"): + dummy_evaluator.make_metrics(1) + + with pytest.raises(ValueError, match="metrics should be None, a string.*"): + dummy_evaluator.make_metrics(1.0) + + with pytest.raises(ValueError, match="metrics should be None, a string.*"): + dummy_evaluator.make_metrics({1, 2, 3}) + + with pytest.raises(ValueError, match="metrics should be None, a string.*"): + dummy_evaluator.make_metrics(_sentinel) + + +def test__make_requirements__all(dummy_evaluator, all_reqs): + assert dummy_evaluator.make_requirements("all") == all_reqs + + +def test__make_requirements__str(dummy_evaluator, all_reqs): + with pytest.raises(ValueError, match="reqs should be 'all'.*"): + dummy_evaluator.make_requirements("") + + with pytest.raises(ValueError, match="reqs should be 'all'.*"): + dummy_evaluator.make_requirements("any_str_but_all") + + +def test__make_requirements__list(dummy_evaluator, all_reqs): + with pytest.raises(ValueError, match="Unknown requirement.*"): + dummy_evaluator.make_requirements(["l1", "kl"]) + + assert dummy_evaluator.make_requirements(list(all_reqs)) == all_reqs + + sub = list(all_reqs)[:2] + assert dummy_evaluator.make_requirements(sub) == set(sub) + + dummy_evaluator.make_requirements(metrics=["l1", "corr_prob_traj_rewards"]) == set( + r for r in all_reqs if r in ["l1", "corr_prob_traj_rewards"] + ) + + +def test__make_requirements__dict(dummy_evaluator, all_reqs): + assert all( + r in all_reqs + for r in dummy_evaluator.make_requirements( + metrics={k: METRICS[k] for k in ["l1", "corr_prob_traj_rewards"]} + ) + ) + with pytest.raises(ValueError, match="Unknown metric name.*"): + dummy_evaluator.make_requirements( + metrics={ + "l1": METRICS["l1"], + "invalid": {"name": "invalid", "requirements": ["anything"]}, + } + ) + with pytest.raises(AssertionError, match="reqs should be None, 'all'.*"): + dummy_evaluator.make_requirements(METRICS) + + +def test__make_requirements__None(dummy_evaluator, all_reqs): + assert dummy_evaluator.make_requirements() == all_reqs + + with pytest.raises(ValueError, match="Cannot compute requirements from.*"): + dummy_evaluator.reqs = _sentinel + dummy_evaluator.metrics = _sentinel + dummy_evaluator.make_requirements() + + +@pytest.mark.parametrize("period,step,target", PERIOD_STEP_TARGET) +def test__should_log_train(constant_evaluator, period, step, target): + constant_evaluator.config.train_log_period = period + assert constant_evaluator.should_log_train(step) is target + + +@pytest.mark.parametrize("period,step,target", PERIOD_STEP_TARGET) +def test__should_checkpoint(constant_evaluator, period, step, target): + constant_evaluator.config.checkpoints_period = period + assert constant_evaluator.should_checkpoint(step) is target + + +@pytest.mark.parametrize("first_it", [True, False]) +@pytest.mark.parametrize("period,step,target", PERIOD_STEP_TARGET) +def test__should_eval(constant_evaluator, period, step, target, first_it): + constant_evaluator.config.period = period + constant_evaluator.config.first_it = first_it + + if step == 1 and first_it and period and period > 0: + target = True + + assert constant_evaluator.should_eval(step) is target + + +@pytest.mark.parametrize("top_k", [None, -1, 0, 1, 2]) +@pytest.mark.parametrize("first_it", [True, False]) +@pytest.mark.parametrize("period,step,target", PERIOD_STEP_TARGET) +def test__should_eval_top_k(constant_evaluator, period, step, target, first_it, top_k): + constant_evaluator.config.top_k_period = period + constant_evaluator.config.top_k = top_k + constant_evaluator.config.first_it = first_it + + if first_it and period and period > 0 and step == 1: + target = True + + if not top_k or top_k <= 0: + target = False + + assert constant_evaluator.should_eval_top_k(step) is target + + +@pytest.mark.parametrize( + "config_for_tests,parameterization", + [ + (None, "default"), + (["env.length=4"], "grid_length_4"), + (["env=ctorus"], "ctorus"), + ], + indirect=[ + # overrides arg for conftest.py::config_for_tests fixture + "config_for_tests" + ], +) +def test__eval(gflownet_for_tests, parameterization): + assert Path("./replay.pkl").exists() + # results: {"metrics": dict[str, float], "figs": list[plt.Figure]} + results = gflownet_for_tests.evaluator.eval() + breakpoint() + + for k, v in results["metrics"].items(): + assert isinstance(k, str) + assert isinstance(v, float) + + if parameterization == "default": + pass + elif parameterization == "grid_length_4": + pass + elif parameterization == "ctorus": + for figname, fig in results["figs"].items(): + assert isinstance(figname, str) + assert isinstance(fig, plt.Figure) + else: + raise ValueError(f"Unknown parameterization: {parameterization}") From 93acac124c3e9312efb981115a5a649996e01f0f Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:50:18 -0500 Subject: [PATCH 048/106] clean up `oracle` files and `legacy.py` --- gflownet/oracle/__init__.py | 3 - gflownet/oracle/molecule.py | 44 -- gflownet/utils/common.py | 4 +- gflownet/utils/legacy.py | 979 ------------------------------------ gflownet/utils/oracle.py | 629 ----------------------- main.py | 2 +- scripts/oracle_annotate.py | 44 -- scripts/oracle_sampler.py | 140 ------ 8 files changed, 3 insertions(+), 1842 deletions(-) delete mode 100644 gflownet/oracle/__init__.py delete mode 100644 gflownet/oracle/molecule.py delete mode 100644 gflownet/utils/legacy.py delete mode 100644 gflownet/utils/oracle.py delete mode 100644 scripts/oracle_annotate.py delete mode 100644 scripts/oracle_sampler.py diff --git a/gflownet/oracle/__init__.py b/gflownet/oracle/__init__.py deleted file mode 100644 index fa0699a12..000000000 --- a/gflownet/oracle/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -``gflownet.oracle`` package docstring: **todo** -""" diff --git a/gflownet/oracle/molecule.py b/gflownet/oracle/molecule.py deleted file mode 100644 index 40d82365b..000000000 --- a/gflownet/oracle/molecule.py +++ /dev/null @@ -1,44 +0,0 @@ -import numpy as np -import numpy.typing as npt -import torch -from xtb.interface import Calculator, Param, XTBException -from xtb.libxtb import VERBOSITY_MUTED - -from gflownet.proxy.base import Proxy - - -class XTBMoleculeEnergy(Proxy): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def __call__(self, states_proxy): - # todo: probably make it parallel with mpi - return torch.tensor( - [self.get_energy(*st) for st in states_proxy], - dtype=self.float, - device=self.device, - ) - - def get_energy( - self, - atom_positions: npt.NDArray[np.float32], - atomic_numbers: npt.NDArray[np.int64], - ) -> float: - """ - Compute energy of a molecule defined by atom_positions and atomic_numbers - """ - calc = Calculator(Param.GFN2xTB, atomic_numbers, atom_positions) - calc.set_verbosity(VERBOSITY_MUTED) - try: - return calc.singlepoint().get_energy() - except XTBException: - return np.nan - - -if __name__ == "__main__": - from gflownet.utils.molecule.conformer_base import get_dummy_ad_conf_base - - conf = get_dummy_ad_conf_base() - proxy = XTBMoleculeEnergy() - energy = proxy.get_energy(conf.get_atom_positions(), conf.get_atomic_numbers()) - print("energy", energy) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 80825a89e..478c91cf8 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -239,7 +239,7 @@ def gflownet_from_config(config): """ # Logger logger = instantiate(config.logger, config, _recursive_=False) - # The proxy is required in the env for scoring: might be an oracle or a model + # The proxy is required in the env for scoring proxy = instantiate( config.proxy, device=config.device, @@ -326,7 +326,7 @@ def load_gflow_net_from_run_path( # Logger logger = instantiate(config.logger, config, _recursive_=False) - # The proxy is required in the env for scoring: might be an oracle or a model + # The proxy is required in the env for scoring proxy = instantiate( config.proxy, device=config.device, diff --git a/gflownet/utils/legacy.py b/gflownet/utils/legacy.py deleted file mode 100644 index 33fb12da4..000000000 --- a/gflownet/utils/legacy.py +++ /dev/null @@ -1,979 +0,0 @@ -"""import statement""" - -import os -import time -from argparse import Namespace -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import yaml - -""" -This is a general utilities file for the active learning pipeline - -To-Do: -""" - - -def get_config(args, override_args, args2config): - """ - Combines YAML configuration file, command line arguments and default arguments into - a single configuration dictionary. - - - Values in YAML file override default values - - Command line arguments override values in YAML file - - Returns - ------- - Namespace - """ - - def _update_config(arg, val, config, override=False): - config_aux = config - for k in args2config[arg]: - if k not in config_aux: - if k is args2config[arg][-1]: - config_aux.update({k: val}) - else: - config_aux.update({k: {}}) - config_aux = config_aux[k] - else: - if k is args2config[arg][-1] and override: - config_aux[k] = val - else: - config_aux = config_aux[k] - - # Read YAML config - if args.yaml_config: - yaml_path = Path(args.yaml_config) - assert yaml_path.exists(), "yaml_config = {}".format(args.yaml_config) - assert yaml_path.suffix in {".yaml", ".yml"} - with yaml_path.open("r") as f: - config = yaml.safe_load(f) - else: - config = {} - # Add args to config: add if not provided; override if in command line - override_args = [ - arg.strip("--").split("=")[0] for arg in override_args if "--" in arg - ] - override_args_extra = [] - for k1 in override_args: - if k1 in args2config: - v1 = args2config[k1] - for k2, v2 in args2config.items(): - if v2 == v1 and k2 != k1: - override_args_extra.append(k2) - override_args = override_args + override_args_extra - for k, v in vars(args).items(): - if k in override_args: - _update_config(k, v, config, override=True) - else: - _update_config(k, v, config, override=False) - return dict2namespace(config) - - -def printRecord(statement): - """ - print a string to command line output and a text file - :param statement: - :return: - """ - print(statement) - if os.path.exists("record.txt"): - with open("record.txt", "a") as file: - file.write("\n" + statement) - else: - with open("record.txt", "w") as file: - file.write("\n" + statement) - - -def letters2numbers(sequences): # Tranforming letters to numbers: - """ - Converts ATCG sequences to numerical values - :param sequences: ATCG-format DNA sequences to be converted - :return: DNA sequences in 1234 format - """ - - my_seq = np.zeros((len(sequences), len(sequences[0]))) - row = 0 - - for seq in sequences: - assert (type(seq) == str) and ( - len(seq) == my_seq.shape[1] - ), "Function inputs must be a list of equal length strings" - col = 0 - for na in seq: - if (na == "a") or (na == "A"): - my_seq[row, col] = 1 - elif (na == "u") or (na == "U") or (na == "t") or (na == "T"): - my_seq[row, col] = 2 - elif (na == "c") or (na == "C"): - my_seq[row, col] = 3 - elif (na == "g") or (na == "G"): - my_seq[row, col] = 4 - col += 1 - row += 1 - - return my_seq - - -def numbers2letters(sequences): # Tranforming letters to numbers: - """ - Converts numerical values to ATGC-format - :param sequences: numerical DNA sequences to be converted - :return: DNA sequences in ATGC format - """ - if type(sequences) != np.ndarray: - sequences = np.asarray(sequences) - - if sequences.ndim < 2: - sequences = np.expand_dims(sequences, 0) - - my_seq = ["" for x in range(len(sequences))] - row = 0 - for j in range(len(sequences)): - seq = sequences[j, :] - assert ( - type(seq) != str - ), "Function inputs must be a list of equal length strings" - for i in range(len(sequences[0])): - na = seq[i] - if na == 1: - my_seq[row] += "A" - elif na == 2: - my_seq[row] += "T" - elif na == 3: - my_seq[row] += "C" - elif na == 4: - my_seq[row] += "G" - row += 1 - return my_seq - - -def getModelName(ensembleIndex): - """ - :param params: parameters of the pipeline we are training - :return: directory label - """ - dirName = "estimator=" + str(ensembleIndex) - - return dirName - - -class bcolors: - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - -def resultsAnalysis(outDir): - """ - analyze the results of a bunch of parallel runs of the active learning pipeline - """ - outDicts = [] - os.chdir(outDir) - for dirs in os.listdir(outDir): - out = np.load(dirs + "/outputsDict.npy", allow_pickle=True).item() - outDicts.append(out) - - # collect info for plotting - numIter = out["params"]["pipeline iterations"] - numModels = out["params"]["model ensemble size"] - numSampler = out["params"]["num samplers"] - optima = [] - testLoss = [] - oracleOptima = [] - for dict in outDicts: - oracleOptima.append(np.amin(dict["oracle outputs"]["energy"])) - optima.append(np.amin(dict["best optima found"])) - testLoss.append(np.amin(dict["model test minima"])) - - # average over repeated runs - oracleOptima = np.asarray(oracleOptima) - optima = np.asarray(optima) - testLoss = np.asarray(testLoss) - - avgDiff = [] - avgLoss = [] - - for i in range(5): # - avgDiff.append( - np.average( - np.abs((oracleOptima[i:-1:5] - optima[i:-1:5]) / oracleOptima[i:-1:5]) - ) - ) - avgLoss.append(np.average(testLoss[i:-1:5])) - - plt.clf() - plt.plot(avgLoss / np.amax(avgLoss), label="test loss") - plt.plot(avgDiff / np.amax(avgDiff), label="pipeline error") - plt.legend() - - -# TODO: dict_size is unused -def binaryDistance(samples, dict_size=None, pairwise=False, extractInds=None): - """ - compute simple sum of distances between sample vectors: distance = disagreement of allele elements. - :param samples: - :return: - """ - # determine if all samples have equal length - """ - lens = np.array([i.shape[-1] for i in samples]) - if len(np.unique(lens)) > 1: # if there are multiple lengths, we need to pad up to a constant length - raise ValueError('Attempted to compute binary distances between samples with different lengths!') - if (len(samples) > 1e3) and (extractInds is None): # one-hot overhead is worth it for larger samples - distances = oneHotDistance(samples, dict_size, pairwise=pairwise, extractInds=extractInds) - elif (len(samples) > 1e3) and (extractInds > 10): # one-hot overhead is worth it for larger samples - distances = oneHotDistance(samples, dict_size, pairwise=pairwise, extractInds=extractInds) - else: - """ - - if extractInds is not None: - nOutputs = extractInds - else: - nOutputs = len(samples) - - if pairwise: # compute every pairwise distances - distances = np.zeros((nOutputs, nOutputs)) - for i in range(nOutputs): - distances[i, :] = np.sum(samples[i] != samples, axis=1) / len(samples[i]) - else: # compute average distance of each sample from all the others - distances = np.zeros(nOutputs) - if len(samples) == nOutputs: # compute distance with itself - for i in range(nOutputs): - distances[i] = np.sum(samples[i] != samples) / len(samples.flatten()) - # print('Compared with itelf.') - else: # compute distance from the training set or random set - references = samples[nOutputs:] - for i in range(nOutputs): - distances[i] = np.sum(samples[i] != references) / len( - references.flatten() - ) - # print('Compared with external reference.') - return distances - - -def oneHotDistance(samples, dict_size, pairwise=False, extractInds=None): - """ - find the minimum single mutation distance (normalized) between sequences - optionally explicitly extract only the first extractInds sequences distances, with respect to themselves and all others - :param samples: - :param pairwise: - :param extractInds: - :return: - """ - # do one-hot encoding - oneHot = np_oneHot( - samples, int(dict_size + 1) - ) # assumes dict is 1-N with 0 padding - oneHot = oneHot.reshape(oneHot.shape[0], int(oneHot.shape[1] * oneHot.shape[2])) - target = oneHot[ - :extractInds - ] # limit the number of samples we are actually interested in - if target.ndim == 1: - target = np.expand_dims(target, 0) - - dists = 1 - target @ oneHot.transpose() / samples.shape[1] - if pairwise: - return dists - else: - return np.average(dists, axis=1) - - -def np_oneHot(samples, uniques): - samples = samples.astype(int) - flatsamples = samples.flatten() - shape = (flatsamples.size, uniques) - one_hot = np.zeros(shape) - rows = np.arange(flatsamples.size) - one_hot[rows, flatsamples] = 1 - return one_hot.reshape(samples.shape[0], samples.shape[1], uniques) - - -def sortTopXSamples(sortedSamples, nSamples=int(1e6), distCutoff=0.2): - # collect top distinct samples - - bestSamples = np.expand_dims( - sortedSamples[0], 0 - ) # start with the best identified sequence - bestInds = [0] - i = -1 - while (len(bestInds) < nSamples) and (i < len(sortedSamples) - 1): - i += 1 - candidate = np.expand_dims(sortedSamples[i], 0) - sampleList = np.concatenate((bestSamples, candidate)) - - dists = binaryDistance(sampleList, pairwise=True)[ - -1, :-1 - ] # pairwise distances between candiate and prior samples - if all(dists > distCutoff): # if the samples are all distinct - bestSamples = np.concatenate((bestSamples, candidate)) - bestInds.append(i) - - return bestInds - - -def numpy_fillna(data): - # Get lengths of each row of data - lens = np.array([len(i) for i in data]) - - # Mask of valid places in each row - mask = np.arange(lens.max()) < lens[:, None] - - # Setup output array and put elements from data into masked positions - out = np.zeros(mask.shape, dtype=data.dtype) - out[mask] = np.concatenate(data) - return out - - -def filterDuplicateSamples(samples, oldDatasetPath=None, returnInds=False): - """ - assumes original dataset contains no duplicates - :param samples: must be np array padded to equal length. If a combination of new and original datasets, critical that the original data comes first. - : param origDatasetLen: if samples is a combination of new and old datasets, set old dataset first with length 'origDatasetLen' - :return: non-duplicate samples and/or indices of such samples - """ - origDatasetLen = 0 # if there is no old dataset, take everything - if oldDatasetPath is not None: - dataset = np.load(oldDatasetPath, allow_pickle=True).item()["samples"] - origDatasetLen = len(dataset) - samples = np.concatenate((dataset, samples), axis=0) - - samplesTuple = [tuple(row) for row in samples] - seen = set() - seen_add = seen.add - - filtered = [ - [samplesTuple[i], i] - for i in range(len(samplesTuple)) - if not (samplesTuple[i] in seen or seen_add(samplesTuple[i])) - ] - filteredSamples = [filtered[i][0] for i in range(len(filtered))][ - origDatasetLen: - ] # unique samples - filteredInds = [filtered[i][1] for i in range(len(filtered))][ - origDatasetLen: - ] # unique sample idxs - - assert ( - len(filteredSamples) > 0 - ), "Sampler returned duplicates only, problem may be completely solved, or sampler is too myopic" - - if returnInds: - return ( - np.asarray(filteredSamples), - np.asarray(filteredInds) - - origDatasetLen, # in samples basis (omitting any prior dataset) - ) - else: - return np.asarray(filteredSamples) - - -def generateRandomSamples( - nSamples, - sampleLengthRange, - dictSize, - oldDatasetPath=None, - variableLength=True, - seed=None, -): - """ - randomly generate a non-repeating set of samples of the appropriate size and composition - :param nSamples: - :param sampleLengthRange: - :param dictSize: - :param variableLength: - :return: - """ - if seed is not None: - np.random.seed(seed) - if variableLength: - samples = [] - while len(samples) < nSamples: - for i in range(sampleLengthRange[0], sampleLengthRange[1] + 1): - samples.extend( - np.random.randint(1, dictSize + 1, size=(int(10 * dictSize * i), i)) - ) - - samples = numpy_fillna( - np.asarray(samples, dtype=object) - ) # pad sequences up to maximum length - samples = filterDuplicateSamples( - samples, oldDatasetPath - ) # this will naturally proportionally punish shorter sequences - if len(samples) < nSamples: - samples = samples.tolist() - - else: # fixed sample size - samples = [] - while len(samples) < nSamples: - samples.extend( - np.random.randint( - 1, dictSize + 1, size=(2 * nSamples, sampleLengthRange[1]) - ) - ) - samples = numpy_fillna( - np.asarray(samples, dtype=object) - ) # pad sequences up to maximum length - samples = filterDuplicateSamples( - samples, oldDatasetPath - ) # this will naturally proportionally punish shorter sequences - if len(samples) < nSamples: - samples = samples.tolist() - - np.random.shuffle( - samples - ) # shuffle so that sequences with different lengths are randomly distributed - samples = samples[ - :nSamples - ] # after shuffle, reduce dataset to desired size, with properly weighted samples - - return samples - - -def get_n_params(model): - """ - count parameters for a pytorch model - :param model: - :return: - """ - pp = 0 - for p in list(model.parameters()): - nn = 1 - for s in list(p.size()): - nn = nn * s - pp += nn - return pp - - -def doAgglomerativeClustering(samples, energies, uncertainties, dict_size, cutoff=0.25): - """ - agglomerative clustering and sorting with pairwise binary distance metric - :param samples: - :param energies: - :param cutoff: - :return: - """ - agglomerate = cluster.AgglomerativeClustering( - n_clusters=None, - affinity="precomputed", - linkage="average", - compute_full_tree=True, - distance_threshold=cutoff, - ).fit(binaryDistance(samples, dict_size, pairwise=True)) - labels = agglomerate.labels_ - nClusters = agglomerate.n_clusters_ - clusters = [] - totInds = [] - clusterEns = [] - clusterVars = [] - for i in range(len(np.unique(labels))): - inds = np.where(labels == i)[0].astype(int) - totInds.extend(inds) - clusters.append([samples[j] for j in inds]) - clusterEns.append([energies[j] for j in inds]) - clusterVars.append([uncertainties[j] for j in inds]) - - return clusters, clusterEns, clusterVars - - -def filterOutputs(outputs, additionalEntries=None): - """ - run filtering on particular outputs dictionaries - """ - - if additionalEntries is not None: - extraSamples = additionalEntries["samples"] - extraScores = additionalEntries["scores"] - extraEnergies = additionalEntries["energies"] - extraUncertainties = additionalEntries["uncertainties"] - samples = np.concatenate((outputs["samples"], extraSamples)) - scores = np.concatenate((outputs["scores"], extraScores)) - energies = np.concatenate((outputs["energies"], extraEnergies)) - uncertainties = np.concatenate((outputs["uncertainties"], extraUncertainties)) - else: - samples = outputs["samples"] - scores = outputs["scores"] - energies = outputs["energies"] - uncertainties = outputs["uncertainties"] - - filteredSamples, filteredInds = filterDuplicateSamples(samples, returnInds=True) - - filteredOutputs = { - "samples": filteredSamples, - "scores": scores[filteredInds], - "energies": energies[filteredInds], - "uncertainties": uncertainties[filteredInds], - } - printRecord( - "Sampler outputs after filtering - best energy = {:.4f}".format( - np.amin(energies) - ) - ) - - return filteredOutputs - - -def clusterAnalysis(clusters, clusterEns, clusterVars): - """ - get the average and minimum energies and variances at these points - :param clusters: - :param clusterEns: - :param clusterVars: - :return: - """ - clusterSize = np.asarray([len(cluster) for cluster in clusters]) - avgClusterEns = np.asarray([np.average(cluster) for cluster in clusterEns]) - minClusterEns = np.asarray([np.amin(cluster) for cluster in clusterEns]) - avgClusterVars = np.asarray([np.average(cluster) for cluster in clusterVars]) - minClusterVars = np.asarray( - [clusterVars[i][np.argmin(clusterEns[i])] for i in range(len(clusterVars))] - ) - minClusterSamples = np.asarray( - [clusters[i][np.argmin(clusterEns[i])] for i in range(len(clusterEns))] - ) - - clusterOrder = np.argsort(minClusterEns) - clusterSize = clusterSize[clusterOrder] - avgClusterEns = avgClusterEns[clusterOrder] - minClusterEns = minClusterEns[clusterOrder] - avgClusterVars = avgClusterVars[clusterOrder] - minClusterVars = minClusterVars[clusterOrder] - minClusterSamples = minClusterSamples[clusterOrder] - - return ( - clusterSize, - avgClusterEns, - minClusterEns, - avgClusterVars, - minClusterVars, - minClusterSamples, - ) - - -class resultsPlotter: - def __init__(self): - self.i = 0 - self.j = 0 - - def process(self, directory): - # get simulation results - os.chdir(directory) - results = np.load("outputsDict.npy", allow_pickle=True).item() - - self.niters = len(results["state dict record"]) - self.nmodels = results["state dict record"][0]["n proxy models"] - - self.trueMin = np.amin(results["oracle outputs"]["energies"]) - self.trueMinSample = results["oracle outputs"]["samples"][ - np.argmin(results["oracle outputs"]["energies"]) - ] - - self.avgTestLoss = np.asarray( - [results["state dict record"][i]["test loss"] for i in range(self.niters)] - ) - self.testStd = np.asarray( - [results["state dict record"][i]["test std"] for i in range(self.niters)] - ) - self.allTestLosses = np.asarray( - [ - results["state dict record"][i]["all test losses"] - for i in range(self.niters) - ] - ) - self.stdEns = np.asarray( - [ - results["state dict record"][i]["best cluster energies"] - for i in range(self.niters) - ] - ) # these come standardized out of the box - self.stdDevs = np.asarray( - [ - results["state dict record"][i]["best cluster deviations"] - for i in range(self.niters) - ] - ) - self.stateSamples = np.asarray( - [ - results["state dict record"][i]["best cluster samples"] - for i in range(self.niters) - ] - ) - self.internalDists = np.asarray( - [ - results["state dict record"][i]["best clusters internal diff"] - for i in range(self.niters) - ] - ) - self.datasetDists = np.asarray( - [ - results["state dict record"][i]["best clusters dataset diff"] - for i in range(self.niters) - ] - ) - self.randomDists = np.asarray( - [ - results["state dict record"][i]["best clusters random set diff"] - for i in range(self.niters) - ] - ) - self.bigDataLoss = np.asarray( - [results["big dataset loss"][i] for i in range(self.niters)] - ) - self.bottom10Loss = np.asarray( - [results["bottom 10% loss"][i] for i in range(self.niters)] - ) - - # get dataset mean and std - target = os.listdir("datasets")[0] - dataset = np.load("datasets/" + target, allow_pickle=True).item() - datasetScores = dataset["scores"] - self.mean = np.mean(datasetScores) - self.std = np.sqrt(np.var(datasetScores)) - - # standardize results - self.stdTrueMin = (self.trueMin - self.mean) / self.std - - # normalize against true answer - self.normedEns = 1 - np.abs(self.stdTrueMin - self.stdEns) / np.abs( - self.stdTrueMin - ) - self.normedDevs = self.stdDevs / np.abs(self.stdTrueMin) - - self.xrange = ( - np.arange(self.niters) * results["config"].al.queries_per_iter - + results["config"].dataset.init_length - ) - - def averageResults(self, directories): - results = [] - for directory in directories: - self.process(directory) - results.append(self.__dict__) - - self.avgbigDataLoss = [] - self.avgbottom10Loss = [] - self.avgavgTestLoss = [] - self.avgtestStd = [] - self.avgstd = [] - self.avgnormedEns = [] - self.avgnormedDevs = [] - self.avginternalDists = [] - self.avgdatasetDists = [] - self.avgrandomDists = [] - for i in range(len(directories)): - self.avgbigDataLoss.append(results[i]["bigDataLoss"]) - self.avgbottom10Loss.append(results[i]["bottom10Loss"]) - self.avgavgTestLoss.append(results[i]["avgTestLoss"]) - self.avgtestStd.append(results[i]["testStd"]) - self.avgstd.append(results[i]["std"]) - self.avgnormedEns.append(results[i]["normedEns"]) - self.avgnormedDevs.append(results[i]["normedDevs"]) - self.avginternalDists.append(results[i]["internalDists"]) - self.avgdatasetDists.append(results[i]["datasetDists"]) - self.avgrandomDists.append(results[i]["randomDists"]) - - self.bigDataLoss = np.average(self.avgbigDataLoss, axis=0) - self.bottom10Loss = np.average(self.avgbottom10Loss, axis=0) - self.avgTestLoss = np.average(self.avgavgTestLoss, axis=0) - self.testStd = np.average(self.avgtestStd, axis=0) - self.std = np.average(self.avgstd, axis=0) - self.normedEns = np.average(self.avgnormedEns, axis=0) - self.normedDevs = np.average(self.avgnormedDevs, axis=0) - self.internalDists = np.average(self.avginternalDists, axis=0) - self.datasetDists = np.average(self.avgdatasetDists, axis=0) - self.randomDists = np.average(self.avgrandomDists, axis=0) - - def plotLosses(self, fignum=1, color="k", label=None): - plt.figure(fignum) - plt.semilogy( - self.xrange, - self.bigDataLoss, - color + ".-", - label=label + " big sample loss", - ) - plt.semilogy( - self.xrange, - self.bottom10Loss, - color + "o-", - label=label + " bottom 10% loss", - ) - plt.fill_between( - self.xrange, - self.avgTestLoss - self.testStd / 2, - self.avgTestLoss + self.testStd / 2, - alpha=0.2, - edgecolor=color, - facecolor=color, - label=label + " test losses", - ) - plt.xlabel("Training Set Size") - plt.ylabel("Smooth L1 Loss") - plt.legend() - - def plotPerformance(self, fignum=1, color="k", label=None, ind=1): - plt.figure(fignum) - plt.plot(self.xrange, self.normedEns[:, 0], color + ".-") - plt.fill_between( - self.xrange, - self.normedEns[:, 0] - self.normedDevs[:, 0] / 2, - self.normedEns[:, 0] + self.normedDevs[:, 0] / 2, - alpha=0.2, - edgecolor=color, - facecolor=color, - label=label + " best optimum + uncertainty", - ) - avgens = np.average(self.normedEns, axis=1) - plt.errorbar( - self.xrange + ind * 10, - avgens, - yerr=[avgens - self.normedEns[:, 0], avgens - self.normedEns[:, 1]], - fmt=color + ".", - ecolor=color, - elinewidth=3, - capsize=1.5, - alpha=0.2, - label=label + " state range", - ) - # for i in range(self.normedEns.shape[1]): - # plt.plot(self.xrange + self.i / 10, self.normedEns[:,i], color + '.') - plt.xlabel("Training Set Size") - plt.ylabel("Performance") - plt.ylim(0, 1) - plt.legend() - - def plotDiversity(self, fignum=1, subplot=1, nsubplots=1, color="k", label=None): - plt.figure(fignum) - square = int(np.ceil(np.sqrt(nsubplots))) - plt.subplot(square, square, subplot) - plt.fill_between( - self.xrange, - np.amin(self.internalDists, axis=1), - np.amax(self.internalDists, axis=1), - alpha=0.2, - hatch="o", - edgecolor=color, - facecolor=color, - label=label + " internal dist", - ) - plt.plot(self.xrange, np.average(self.internalDists, axis=1), color + "-") - plt.fill_between( - self.xrange, - np.amin(self.datasetDists, axis=1), - np.amax(self.datasetDists, axis=1), - alpha=0.2, - hatch="-", - edgecolor=color, - facecolor=color, - label=label + " dataset dist", - ) - plt.plot(self.xrange, np.average(self.datasetDists, axis=1), color + "-") - plt.fill_between( - self.xrange, - np.amin(self.randomDists, axis=1), - np.amax(self.randomDists, axis=1), - alpha=0.2, - hatch="/", - edgecolor=color, - facecolor=color, - label=label + " random dist", - ) - plt.plot(self.xrange, np.average(self.randomDists, axis=1), color + "-") - plt.xlabel("Training Set Size") - plt.ylabel("Binary Distances") - plt.legend() - - def plotDiversityProduct(self, fignum=1, color="k", label=None): - plt.figure(fignum) - divXEn = ( - self.internalDists * self.normedEns - ) # pointwise product of internal distance metric and normalized energy (higher is better) - plt.fill_between( - self.xrange, - np.amin(divXEn, axis=1), - np.amax(divXEn, axis=1), - alpha=0.2, - edgecolor=color, - facecolor=color, - label=label + " dist evolution", - ) - plt.xlabel("Training Set Size") - plt.ylabel("Energy x dist") - plt.legend() - - def plotDiversityMesh( - self, fignum=1, subplot=1, nsubplots=1, color="k", label=None - ): - plt.figure(fignum) - square = int(np.ceil(np.sqrt(nsubplots))) - plt.subplot(square, square, subplot) - flatDist = self.internalDists.flatten() - flatEns = self.normedEns.flatten() - ttime = np.zeros_like(self.internalDists) - for i in range(self.niters): - ttime[i] = i + 1 - flatTime = ttime.flatten() - plt.tricontourf(flatDist, flatEns, flatTime) - plt.title("Diversity and Energy over time") - plt.xlabel("Internal Distance") - plt.ylabel("Sample Energy") - plt.xlim(0, 1) - plt.ylim(0, 1) - plt.clim(1, self.niters) - plt.colorbar() - plt.tight_layout() - - -def dict2namespace(data_dict): - """ - Recursively converts a dictionary and its internal dictionaries into an - argparse.Namespace - - Parameters - ---------- - data_dict : dict - The input dictionary - - Return - ------ - data_namespace : argparse.Namespace - The output namespace - """ - for k, v in data_dict.items(): - if isinstance(v, dict): - data_dict[k] = dict2namespace(v) - else: - pass - data_namespace = Namespace(**data_dict) - - return data_namespace - - -def namespace2dict(data_namespace): - """ - Recursively converts a dictionary and its internal dictionaries into an - argparse.Namespace - - Parameters - ---------- - data_dict : dict - The input dictionary - - Return - ------ - data_namespace : argparse.Namespace - The output namespace - """ - data_dict = {} - for k in vars(data_namespace): - if isinstance(getattr(data_namespace, k), Namespace): - data_dict.update({k: namespace2dict(getattr(data_namespace, k))}) - else: - data_dict.update({k: getattr(data_namespace, k)}) - - return data_dict - - -def numpy2python(results_dict): - """ - Recursively converts the numpy types into native Python types in order to - enable proper dumping into YAML files: - - Parameters - ---------- - results_dict : dict - The input dictionary - - Return - ------ - results_dict : dict - The modified dictionary - """ - - def convert(v): - if isinstance(v, np.ndarray): - if np.ndim(v) == 1: - return v.tolist() - elif isinstance(v, (int, np.integer)): - return int(v) - elif isinstance(v, (float, np.float, np.float32)): - return float(v) - elif isinstance(v, list): - for idx, el in enumerate(v): - v[idx] = convert(el) - return v - elif isinstance(v, dict): - return numpy2python(v) - elif isinstance(v, Namespace): - return numpy2python(vars(v)) - else: - return v - - for k, v in results_dict.items(): - if isinstance(v, dict): - numpy2python(v) - elif isinstance(v, Namespace): - numpy2python(vars(v)) - else: - results_dict[k] = convert(v) - - return results_dict - - -def normalizeDistCutoff(cutoff): - return (1 + np.tanh(cutoff)) / 2 - - -def bracket_dot_to_num(sequences, maxlen): - """ - convert from (((...))) notation to 111222333 - """ - my_seq = np.zeros((len(sequences), maxlen)) - row = 0 - - for seq in sequences: - col = 0 - for na in seq: - if na == "(": - my_seq[row, col] = 1 - elif na == ".": - my_seq[row, col] = 2 - elif na == ")": - my_seq[row, col] = 3 - col += 1 - row += 1 - - return my_seq - - -def add_bool_arg(parser, name, default=False): - group = parser.add_mutually_exclusive_group(required=False) - group.add_argument("--" + name, dest=name, action="store_true") - group.add_argument("--no-" + name, dest=name, action="store_false") - parser.set_defaults(**{name: default}) - return parser - - -def handle_logdir(): - # TODO - just copy-pasted - if "logdir" in config and config.logdir is not None: - if not Path(config.logdir).exists() or config.overwrite_logdir: - Path(config.logdir).mkdir(parents=True, exist_ok=True) - with open(config.logdir + "/config.yml", "w") as f: - yaml.dump( - numpy2python(namespace2dict(config)), f, default_flow_style=False - ) - torch.set_num_threads(1) - main(config) - else: - print(f"logdir {config.logdir} already exists! - Ending run...") - else: - print(f"working directory not defined - Ending run...") diff --git a/gflownet/utils/oracle.py b/gflownet/utils/oracle.py deleted file mode 100644 index b3d1f39a4..000000000 --- a/gflownet/utils/oracle.py +++ /dev/null @@ -1,629 +0,0 @@ -"""import statements""" - -import sys - -from omegaconf import ListConfig -from potts_utils import load_potts_model, potts_energy -from seqfold import dg, fold -from utils import * - -try: # we don't always install these on every platform - from nupack import * -except: - print( - "COULD NOT IMPORT NUPACK ON THIS DEVICE - proceeding, but will crash with nupack oracle selected" - ) - pass -try: - from bbdob import DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, WModel - from bbdob.utils import idx2one_hot -except: - print( - "COULD NOT IMPORT BB-DOB ON THIS DEVICE - proceeding, but will crash with BB-DOB oracle selected" - ) - pass - - -""" -This script computes a binding score for a given sequence or set of sequences - -> Inputs: numpy integer arrays - different oracles with different requirements -> Outputs: oracle outputs - usually numbers - -config -'dataset seed' - self explanatory -'dict size' - number of possible states per sequence element - e.g., for ATGC 'dict size' = 4 -'variable sample length', 'min sample length', 'max sample length' - for determining the length and variability of sample sequences -'init dataset length' - number of samples for initial (random) dataset -'dataset' - name of dataset to be saved -""" - - -class Oracle: - def __init__( - self, - oracle, - seed=0, - seq_len=30, - dict_size=4, - min_len=30, - max_len=30, - variable_len=True, - init_len=0, - energy_weight=False, - nupack_target_motif="", - seed_toy=0, - ): - """ - initialize the oracle - :param config: - """ - self.seed = seed - self.seq_len = seq_len - self.dict_size = dict_size - self.min_len = min_len - self.max_len = max_len - self.init_len = init_len - self.variable_len = variable_len - self.oracle = oracle - self.energy_weight = energy_weight - self.nupack_target_motif = nupack_target_motif - self.seed_toy = seed_toy - - np.random.seed(self.seed_toy) - if not "nupack" in self.oracle: - self.initRands() # initialize random numbers for hand-made oracles - - def initRands(self): - """ - initialize random numbers for custom-made toy functions - :return: - """ - - # set these to be always positive to play nice with gFlowNet sampling - if True: # self.config.test_mode: - self.linFactors = -np.ones( - self.seq_len - ) # Uber-simple function, for testing purposes - actually nearly functionally identical to one-max, I believe - else: - self.linFactors = np.abs( - np.random.randn(self.seq_len) - ) # coefficients for linear toy energy - - hamiltonian = np.random.randn(self.seq_len, self.seq_len) # energy function - self.hamiltonian = ( - np.tril(hamiltonian) + np.tril(hamiltonian, -1).T - ) # random symmetric matrix - - pham = np.zeros((self.seq_len, self.seq_len, self.dict_size, self.dict_size)) - for i in range(pham.shape[0]): - for j in range(i, pham.shape[1]): - for k in range(pham.shape[2]): - for l in range(k, pham.shape[3]): - num = -np.random.uniform(0, 1) - pham[i, j, k, l] = num - pham[i, j, l, k] = num - pham[j, i, k, l] = num - pham[j, i, l, k] = num - self.pottsJ = ( - pham # multilevel spin Hamiltonian (Potts Hamiltonian) - coupling term - ) - self.pottsH = np.random.randn( - self.seq_len, self.dict_size - ) # Potts Hamiltonian - onsite term - - # W-model parameters - # first get the binary dimension size - aa = np.arange(self.dict_size) - if self.variable_len: - aa = np.clip(aa, 1, self.dict_size) # merge padding with class 1 - x0 = np.binary_repr(aa[-1]) - dimension = int(len(x0) * self.max_len) - - mu = np.random.randint(1, dimension + 1) - v = np.random.randint(1, dimension + 1) - m = np.random.randint(1, dimension) - n = np.random.randint(1, dimension) - gamma = np.random.randint(0, int(n * (n - 1) / 2)) - self.mu, self.v, self.m, self.n, self.gamma = [mu, v, m, n, gamma] - - def initializeDataset( - self, save=True, returnData=False, customSize=None, custom_seed=None - ): - """ - generate an initial toy dataset with a given number of samples - need an extra factor to speed it up (duplicate filtering is very slow) - :param numSamples: - :return: - """ - data = {} - if custom_seed: - np.random.seed(custom_seed) - else: - np.random.seed(self.seed) - if customSize is None: - datasetLength = self.init_len - else: - datasetLength = customSize - - if self.variable_len: - samples = [] - while len(samples) < datasetLength: - for i in range(self.min_len, self.max_len + 1): - samples.extend( - np.random.randint( - 0 + 1, - self.dict_size + 1, - size=(int(10 * self.dict_size * i), i), - ) - ) - - samples = self.numpy_fillna( - np.asarray(samples, dtype=object) - ) # pad sequences up to maximum length - samples = filterDuplicateSamples( - samples - ) # this will naturally proportionally punish shorter sequences - if len(samples) < datasetLength: - samples = samples.tolist() - np.random.shuffle( - samples - ) # shuffle so that sequences with different lengths are randomly distributed - samples = samples[ - :datasetLength - ] # after shuffle, reduce dataset to desired size, with properly weighted samples - else: # fixed sample size - samples = np.random.randint( - 1, self.dict_size + 1, size=(datasetLength, self.max_len) - ) - samples = filterDuplicateSamples(samples) - while len(samples) < datasetLength: - samples = np.concatenate( - ( - samples, - np.random.randint( - 1, self.dict_size + 1, size=(datasetLength, self.max_len) - ), - ), - 0, - ) - samples = filterDuplicateSamples(samples) - - data["samples"] = samples - data["energies"] = self.score(data["samples"]) - - if save: - np.save("datasets/" + self.oracle, data) - if returnData: - return data - - def score(self, queries): - """ - assign correct scores to selected sequences - :param queries: sequences to be scored - :return: computed scores - """ - if isinstance(queries, list): - queries = np.asarray(queries) # convert queries to array - block_size = int(1e4) # score in blocks of maximum 10000 - scores_list = [] - scores_dict = {} - for idx in range(len(queries) // block_size + bool(len(queries) % block_size)): - queryBlock = queries[idx * block_size : (idx + 1) * block_size] - scores_block = self.getScore(queryBlock) - if isinstance(scores_block, dict): - for k, v in scores_block.items(): - if k in scores_dict: - scores_dict[k].extend(list(v)) - else: - scores_dict.update({k: list(v)}) - else: - scores_list.extend(self.getScore(queryBlock)) - if len(scores_list) > 0: - return np.asarray(scores_list) - else: - return {k: np.asarray(v) for k, v in scores_dict.items()} - - def getScore(self, queries): - if self.oracle == "linear": - return self.linearToy(queries) - elif self.oracle == "potts": - return self.PottsEnergy(queries) - elif self.oracle == "potts new": - return self.PottsEnergyNew(queries) - elif self.oracle == "inner product": - return self.toyHamiltonian(queries) - elif self.oracle == "seqfold": - return self.seqfoldScore(queries) - elif self.oracle == "nupack energy": - return self.nupackScore(queries, returnFunc="energy") - elif self.oracle == "nupack pins": - return self.nupackScore( - queries, returnFunc="pins", energy_weighting=self.energy_weight - ) - elif self.oracle == "nupack pairs": - return self.nupackScore( - queries, returnFunc="pairs", energy_weighting=self.energy_weight - ) - elif self.oracle == "nupack open loop": - return self.nupackScore( - queries, returnFunc="open loop", energy_weighting=self.energy_weight - ) - elif self.oracle == "nupack motif": - return self.nupackScore( - queries, - returnFunc="motif", - motif=nupack_target_motif, - energy_weighting=self.energy_weight, - ) - - elif ( - (self.oracle == "onemax") - or (self.oracle == "twomin") - or (self.oracle == "fourpeaks") - or (self.oracle == "deceptivetrap") - or (self.oracle == "nklandscape") - or (self.oracle == "wmodel") - ): - return self.BB_DOB_functions(queries) - elif isinstance(self.oracle, (list, ListConfig)) and all( - ["nupack " in el for el in self.oracle] - ): - return self.nupackScore( - queries, returnFunc=[el.replace("nupack ", "") for el in self.oracle] - ) - elif ( - isinstance(self.oracle, (list, ListConfig)) - and self.oracle[0] == "potts new" - ): - return self.PottsEnergyNew(queries) - else: - raise NotImplementedError("Unknown oracle type") - - def BB_DOB_functions(self, queries): - """ - BB-DOB OneMax benchmark - :param queries: - :return: - """ - if self.variable_len: - queries = np.clip(queries, 1, self.dict_size) # merge padding with class 1 - - x0 = [ - np.binary_repr((queries[i][j] - 1).astype("uint8"), width=2) - for i in range(len(queries)) - for j in range(self.max_len) - ] # convert to binary - x0 = ( - np.asarray(x0).astype(str).reshape(len(queries), self.max_len) - ) # reshape to proper size - x0 = ["".join(x0[i]) for i in range(len(x0))] # concatenate to binary strings - x1 = np.zeros((len(queries), len(x0[0])), int) # initialize array - for i in range(len(x0)): # finally, as an array (took me long enough) - x1[i] = np.asarray(list(x0[i])).astype(int) - - dimension = x1.shape[1] - - x1 = idx2one_hot(x1, 2) # convert to BB_DOB one_hot format - - objective = self.getObjective(dimension) - - evals, info = objective(x1) - - return evals - - def getObjective(self, dimension): - if self.oracle == "onemax": # very limited in our DNA one-hot encoding - objective = OneMax(dimension) - elif self.oracle == "twomin": - objective = TwoMin(dimension) - elif self.oracle == "fourpeaks": # very limited in our DNA one-hot encoding - objective = FourPeaks(dimension, t=3) - elif self.oracle == "deceptivetrap": - objective = DeceptiveTrap(dimension, minimize=True) - elif self.oracle == "nklandscape": - objective = NKLandscape(dimension, minimize=True) - elif self.oracle == "wmodel": - objective = WModel( - dimension, - mu=self.mu, - v=self.v, - m=self.m, - n=self.n, - gamma=self.gamma, - minimize=True, - ) - else: - printRecord(self.oracle + " is not a valid dataset!") - sys.exit() - - return objective - - def linearToy(self, queries): - """ - return the energy of a toy model for the given set of queries - sites are completely uncorrelated - :param queries: - :return: - """ - energies = ( - queries @ self.linFactors - ) # simple matmul - padding entries (zeros) have zero contribution - - return energies - - def toyHamiltonian(self, queries): - """ - return the energy of a toy model for the given set of queries - sites may be correlated if they have a strong coupling (off diagonal term in the Hamiltonian) - :param queries: - :return: - """ - - energies = np.zeros(len(queries)) - for i in range(len(queries)): - energies[i] = ( - queries[i] @ self.hamiltonian @ queries[i].transpose() - ) # compute energy for each sample via inner product with the Hamiltonian - - return energies - - def PottsEnergy(self, queries): - """ - test oracle - randomly generated Potts Multilevel Spin Hamiltonian - each pair of sites is correlated depending on the occupation of each site - :param queries: sequences to be scored - :return: - """ - - # DNA Potts model - OLD - # coupling_dict = scipy.io.loadmat('40_level_scored.mat') - # N = coupling_dict['h'].shape[1] # length of DNA chain - # assert N == len(queries[0]), "Hamiltonian and proposed sequences are different sizes!" - # h = coupling_dict['h'] - # J = coupling_dict['J'] - - energies = np.zeros(len(queries)) - for k in range(len(queries)): - nnz = np.count_nonzero(queries[k]) - # potts hamiltonian - for ii in range(nnz): # ignore padding terms - energies[k] += self.pottsH[ - ii, queries[k, ii] - 1 - ] # add onsite term and account for indexing (e.g. 1-4 -> 0-3) - - for jj in range( - ii, nnz - ): # this is duplicated on lower triangle so we only need to do it from i-L - energies[k] += ( - 2 * self.pottsJ[ii, jj, queries[k, ii] - 1, queries[k, jj] - 1] - ) # site-specific couplings - - return energies - - def PottsEnergyNew(self, sequences): - # Load the potts model - J, h = load_potts_model(435) - - # Compute energies - energies = np.zeros(len(sequences)) - for idx, seq in enumerate(sequences): - energies[idx] = potts_energy(J, h, seq) - - return energies - - def seqfoldScore(self, queries, returnSS=False): - """ - get the secondary structure for a given sequence - using seqfold here - identical features are available using nupack, though results are sometimes different - :param sequence: - :return: - """ - temperature = 37.0 # celcius - sequences = self.numbers2letters(queries) - - energies = np.zeros(len(sequences)) - strings = [] - pairLists = [] - i = -1 - for sequence in sequences: - i += 1 - en = dg( - sequence, temp=temperature - ) # get predicted minimum energy of folded structure - if np.isfinite(en): - if ( - en > 1500 - ): # no idea why it does this but sometimes it adds 1600 - we will upgrade this to nupack in the future - energies[i] = en - 1600 - else: - energies[i] = en - else: - energies[i] = 5 # np.nan # set infinities as being very unlikely - - if returnSS: - structs = fold(sequence) # identify structural features - # print(round(sum(s.e for s in structs), 2)) # predicted energy of the final structure - - desc = ["."] * len(sequence) - pairList = [] - for s in structs: - pairList.append(s.ij[0]) - if len(s.ij) == 1: - i, j = s.ij[0] - desc[i] = "(" - desc[j] = ")" - - ssString = "".join(desc) # secondary structure string - strings.append(ssString) - pairList = np.asarray(pairList) + 1 # list of paired bases - pairLists.append(pairList) - - if returnSS: - return energies, strings, pairLists - else: - return energies - - def numbers2letters( - self, sequences - ): # Tranforming letters to numbers (1234 --> ATGC) - """ - Converts numerical values to ATGC-format - :param sequences: numerical DNA sequences to be converted - :return: DNA sequences in ATGC format - """ - if type(sequences) != np.ndarray: - sequences = np.asarray(sequences) - - my_seq = ["" for x in range(len(sequences))] - row = 0 - for j in range(len(sequences)): - seq = sequences[j, :] - assert ( - type(seq) != str - ), "Function inputs must be a list of equal length strings" - for i in range(len(sequences[0])): - na = seq[i] - if na == 1: - my_seq[row] += "A" - elif na == 2: - my_seq[row] += "T" - elif na == 3: - my_seq[row] += "C" - elif na == 4: - my_seq[row] += "G" - row += 1 - return my_seq - - def numpy_fillna(self, data): - """ - function to pad uneven-length vectors up to the max with zeros - :param data: - :return: - """ - # Get lengths of each row of data - lens = np.array([len(i) for i in data]) - - # Mask of valid places in each row - mask = np.arange(lens.max()) < lens[:, None] - - # Setup output array and put elements from data into masked positions - out = np.zeros(mask.shape, dtype=object) - out[mask] = np.concatenate(data) - return out - - def nupackScore( - self, queries, returnFunc="energy", energy_weighting=False, motif=None - ): - # Nupack requires Linux OS. - # use nupack instead of seqfold - more stable and higher quality predictions in general - # returns the energy of the most probable structure only - #:param queries: - #:param returnFunct 'energy' 'pins' 'pairs' - #:return: - - temperature = 310.0 # Kelvin - ionicStrength = 1.0 # molar - if not isinstance(queries[0], str): - sequences = self.numbers2letters(queries) - else: - sequences = queries - - energies = np.zeros(len(sequences)) - nPins = np.zeros(len(sequences)).astype(int) - nPairs = 0 - ssStrings = np.zeros(len(sequences), dtype=object) - - # parallel evaluation - fast - strandList = [] - comps = [] - i = -1 - for sequence in sequences: - i += 1 - strandList.append(Strand(sequence, name="strand{}".format(i))) - comps.append(Complex([strandList[-1]], name="comp{}".format(i))) - - set = ComplexSet( - strands=strandList, complexes=SetSpec(max_size=1, include=comps) - ) - model1 = Model(material="dna", celsius=temperature - 273, sodium=ionicStrength) - results = complex_analysis(set, model=model1, compute=["mfe"]) - for i in range(len(energies)): - energies[i] = results[comps[i]].mfe[0].energy - ssStrings[i] = str(results[comps[i]].mfe[0].structure) - - dict_return = {} - if "pins" in returnFunc: - for i in range(len(ssStrings)): - indA = 0 # hairpin completion index - for j in range(len(sequences[i])): - if ssStrings[i][j] == "(": - indA += 1 - elif ssStrings[i][j] == ")": - indA -= 1 - if indA == 0: # if we come to the end of a distinct hairpin - nPins[i] += 1 - dict_return.update({"pins": -nPins}) - if "pairs" in returnFunc: - nPairs = np.asarray([ssString.count("(") for ssString in ssStrings]).astype( - int - ) - dict_return.update({"pairs": -nPairs}) - if "energy" in returnFunc: - dict_return.update( - {"energy": energies} - ) # this is already negative by construction in nupack - - if "open loop" in returnFunc: - biggest_loop = np.zeros(len(ssStrings)) - for i in range( - len(ssStrings) - ): # measure all the open loops and return the largest - loops = [0] # size of loops - counting = 0 - indA = 0 - # loop completion index - for j in range(len(sequences[i])): - if ssStrings[i][j] == "(": - counting = 1 - indA = 0 - if (ssStrings[i][j] == ".") and (counting == 1): - indA += 1 - if (ssStrings[i][j] == ")") and (counting == 1): - loops.append(indA) - counting = 0 - biggest_loop[i] = max(loops) - dict_return.update({"open loop": -biggest_loop}) - - if ( - "motif" in returnFunc - ): # searches for a particular fold NOTE searches for this exact aptamer, not subsections or longer sequences with this as just one portion - #'((((....))))((((....))))....(((....)))' - # pad strings up to max length for binary distance calculation - padded_strings = bracket_dot_to_num(ssStrings, maxlen=self.max_len) - padded_motif = np.expand_dims( - bracket_dot_to_num([motif, motif], maxlen=self.max_len)[0], 0 - ) - motif_distance = binaryDistance( - np.concatenate((padded_motif, padded_strings), axis=0), pairwise=True - )[ - 0, 1: - ] # the first element is the motif we are looking for - take everything after this - dict_return.update( - {"motif": motif_distance - 1} - ) # result is normed on 0-1, so dist-1 gives scaling from 0(bad) to -1(good) - - if energy_weighting: - for key in dict_return.keys(): - if key != "energy": - dict_return[key] = dict_return[key] * np.tanh( - np.abs(energies) / 2 - ) # positive tahn of the energies, scaled - - if isinstance(returnFunc, list): - if len(returnFunc) > 1: - return dict_return - else: - return dict_return[returnFunc[0]] - else: - return dict_return[returnFunc] diff --git a/main.py b/main.py index 94425e0a2..e9d6401b3 100644 --- a/main.py +++ b/main.py @@ -31,7 +31,7 @@ def main(config): # Logger logger = hydra.utils.instantiate(config.logger, config, _recursive_=False) - # The proxy is required in the env for scoring: might be an oracle or a model + # The proxy is required in the env for scoring proxy = hydra.utils.instantiate( config.proxy, device=config.device, diff --git a/scripts/oracle_annotate.py b/scripts/oracle_annotate.py deleted file mode 100644 index 2236eb36f..000000000 --- a/scripts/oracle_annotate.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Annotates a data set with an oracle -""" -import hydra -import pandas as pd -from omegaconf import DictConfig, ListConfig, OmegaConf -from oracle import Oracle - - -@hydra.main(version_base=None, config_path="config", config_name="default") -def main(cfg: DictConfig): - # Make cfg root the specific config of this script - cfg = cfg.oracle_annotate - print(OmegaConf.to_yaml(cfg)) - if cfg.env_id == "aptamers": - oracle = Oracle( - oracle=cfg.oracle, - ) - elif cfg.env_id == "grid": - oracle = None - else: - raise NotImplementedError - # Data set - df_input = pd.read_csv(cfg.input_csv, index_col=0) - # Query oracles - energies = oracle.score(df_input.samples.values) - # Build output CSV - if isinstance(cfg.oracle, (list, ListConfig)): - oracles = [el.replace("nupack ", "") for el in cfg.oracle] - if cfg.output_csv: - if isinstance(energies, dict): - energies.update( - {"samples": df_input.samples.values, "energies": energies[oracles[0]]} - ) - df = pd.DataFrame(energies) - else: - df = pd.DataFrame( - {"samples": df_input.samples.values, "energies": energies} - ) - df.to_csv(cfg.output_csv) - - -if __name__ == "__main__": - main() diff --git a/scripts/oracle_sampler.py b/scripts/oracle_sampler.py deleted file mode 100644 index 15c59bfe7..000000000 --- a/scripts/oracle_sampler.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -Script to create data set of with nupack labels. -""" - -import os -import pickle -import time -from argparse import ArgumentParser -from pathlib import Path - -import numpy as np -import pandas as pd -import yaml -from oracle import Oracle -from tqdm import tqdm -from utils import get_config, namespace2dict, numpy2python - - -def add_args(parser): - """ - Adds command-line arguments to parser - - Returns - ------- - argparse.ArgumentParser - The parser with added arguments - """ - args2config = {} - parser.add_argument( - "-y", - "--yaml_config", - default=None, - type=str, - help="YAML configuration file", - ) - args2config.update({"yaml_config": ["yaml_config"]}) - parser.add_argument( - "--seed_toy", - type=int, - default=0, - ) - args2config.update({"seed_toy": ["seeds", "toy_oracle"]}) - parser.add_argument( - "--seed_dataset", - type=int, - default=0, - ) - args2config.update({"seed_dataset": ["seeds", "dataset"]}) - parser.add_argument( - "--oracle", - nargs="+", - default="nupack energy", - help="linear, potts, nupack energy, nupack pairs, nupack pins", - ) - args2config.update({"oracle": ["dataset", "oracle"]}) - parser.add_argument( - "--nalphabet", - type=int, - default=4, - help="Alphabet size", - ) - args2config.update({"nalphabet": ["dataset", "dict_size"]}) - parser.add_argument( - "--fixed_length", - dest="variable_length", - action="store_false", - default=True, - help="Models will sample within ranges set below", - ) - args2config.update({"variable_length": ["dataset", "variable_length"]}) - parser.add_argument("--min_length", type=int, default=10) - args2config.update({"min_length": ["dataset", "min_length"]}) - parser.add_argument("--max_length", type=int, default=40) - args2config.update({"max_length": ["dataset", "max_length"]}) - parser.add_argument( - "--nsamples", - type=int, - default=int(1e2), - help="Number of samples", - ) - args2config.update({"nsamples": ["dataset", "init_length"]}) - parser.add_argument( - "--no_indices", - dest="no_indices", - action="store_true", - default=False, - help="Omit indices in output CSV", - ) - args2config.update({"no_indices": ["no_indices"]}) - parser.add_argument( - "--output_csv", - type=str, - default=None, - help="Output CSV", - ) - args2config.update({"output_csv": ["output"]}) - return parser, args2config - - -def main(args): - oracle = Oracle( - seed=args.seeds.dataset, - seq_len=args.dataset.max_length, - dict_size=args.dataset.dict_size, - min_len=args.dataset.min_length, - max_len=args.dataset.max_length, - oracle=args.dataset.oracle, - variable_len=args.dataset.variable_length, - init_len=args.dataset.init_length, - seed_toy=args.seeds.toy_oracle, - ) - samples_dict = oracle.initializeDataset(save=False, returnData=True) - energies = samples_dict["energies"] - samples_mat = samples_dict["samples"] - seq_letters = oracle.numbers2letters(samples_mat) - seq_ints = ["".join([str(el) for el in seq if el > 0]) for seq in samples_mat] - if isinstance(energies, dict): - energies.update({"samples": seq_letters, "indices": seq_ints}) - df = pd.DataFrame(energies) - else: - df = pd.DataFrame( - {"samples": seq_letters, "indices": seq_ints, "energies": energies} - ) - if args.output: - output_yml = Path(args.output).with_suffix(".yml") - with open(output_yml, "w") as f: - yaml.dump(numpy2python(namespace2dict(args)), f, default_flow_style=False) - if args.no_indices: - df.drop(columns="indices", inplace=True) - df.to_csv(args.output) - - -if __name__ == "__main__": - parser = ArgumentParser() - _, override_args = parser.parse_known_args() - parser, args2config = add_args(parser) - args = parser.parse_args() - config = get_config(args, override_args, args2config) - print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(config).items()])) - main(config) From e740e6256be42826f8f2617201c1ce3fa264be37 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:53:54 -0500 Subject: [PATCH 049/106] refactor `active_learning` to `use_context` --- config/gflownet/gflownet.yaml | 2 +- gflownet/gflownet.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config/gflownet/gflownet.yaml b/config/gflownet/gflownet.yaml index d0acec1b1..bc74c38cf 100644 --- a/config/gflownet/gflownet.yaml +++ b/config/gflownet/gflownet.yaml @@ -53,4 +53,4 @@ replay_sampling: permutation train_sampling: permutation num_empirical_loss: 200000 sample_only: False -active_learning: False +use_context: False diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 6cb934211..4399eb077 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -49,7 +49,7 @@ def __init__( num_empirical_loss, eval_config, state_flow=None, - active_learning=False, + use_context=False, sample_only=False, replay_sampling="permutation", train_sampling="permutation", @@ -101,9 +101,9 @@ def __init__( state_flow : dict, optional State flow config dictionary. See `gflownet.yaml:state_flow` for details. By default None. - active_learning : bool, optional - Whether this GFlowNetAgent is part of an active learning loop, by default - False. This means the logger will use its context in metrics names. + use_context : bool, optional + Whether the logger will use its context in metrics names. Formerly the + `active_learning: bool` flag. By default False. sample_only : bool, optional This GFNA is only going to be used to sample, no need to make the train/test buffer. @@ -247,7 +247,7 @@ def __init__( self.tau = optimizer.bootstrap_tau self.ema_alpha = optimizer.ema_alpha self.early_stopping = optimizer.early_stopping - self.use_context = active_learning + self.use_context = use_context self.logsoftmax = torch.nn.LogSoftmax(dim=1) # Training self.mask_invalid_actions = mask_invalid_actions From bb0486ca40c97ebeea65de90ecce97eba253da2d Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 18:56:04 -0500 Subject: [PATCH 050/106] Remove `sample_only` gflownet arg (and config) and `make_train_test` in `Buffer.__init__()` --- config/gflownet/gflownet.yaml | 1 - gflownet/gflownet.py | 8 +------- gflownet/utils/buffer.py | 1 - 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/config/gflownet/gflownet.yaml b/config/gflownet/gflownet.yaml index bc74c38cf..82e3f0981 100644 --- a/config/gflownet/gflownet.yaml +++ b/config/gflownet/gflownet.yaml @@ -52,5 +52,4 @@ replay_sampling: permutation # Train data set backward sampling train_sampling: permutation num_empirical_loss: 200000 -sample_only: False use_context: False diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 4399eb077..9574becce 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -50,7 +50,6 @@ def __init__( eval_config, state_flow=None, use_context=False, - sample_only=False, replay_sampling="permutation", train_sampling="permutation", **kwargs, @@ -104,9 +103,6 @@ def __init__( use_context : bool, optional Whether the logger will use its context in metrics names. Formerly the `active_learning: bool` flag. By default False. - sample_only : bool, optional - This GFNA is only going to be used to sample, no need to make the train/test - buffer. replay_sampling : str, optional Type of sampling for the replay buffer. See :method:`~gflownet.utils.buffer.select`. By default "permutation". @@ -161,9 +157,7 @@ def __init__( # Buffers self.replay_sampling = replay_sampling self.train_sampling = train_sampling - self.buffer = Buffer( - **buffer, env=self.env, make_train_test=not sample_only, logger=logger - ) + self.buffer = Buffer(**buffer, env=self.env, logger=logger) # Train set statistics and reward normalization constant if self.buffer.train is not None: energies_stats_tr = [ diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index fbdd6252d..d2ca3684c 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -19,7 +19,6 @@ class Buffer: def __init__( self, env, - make_train_test=False, replay_capacity=0, output_csv=None, data_path=None, From 0edbc6bb567a45fee5c293929473571aa3725ee1 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 19:06:08 -0500 Subject: [PATCH 051/106] revert standardize `main` with `gflownet_from_config` --- gflownet/utils/common.py | 23 ++++++++++++++++++++++- main.py | 5 +++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 478c91cf8..094d8aea2 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -239,12 +239,14 @@ def gflownet_from_config(config): """ # Logger logger = instantiate(config.logger, config, _recursive_=False) + # The proxy is required in the env for scoring proxy = instantiate( config.proxy, device=config.device, float_precision=config.float_precision, ) + # The proxy is passed to env and used for computing rewards env = instantiate( config.env, @@ -252,8 +254,11 @@ def gflownet_from_config(config): device=config.device, float_precision=config.float_precision, ) + + # The policy is used to model the probability of a forward/backward action forward_config = parse_policy_config(config, kind="forward") backward_config = parse_policy_config(config, kind="backward") + forward_policy = instantiate( forward_config, env=env, @@ -267,17 +272,33 @@ def gflownet_from_config(config): float_precision=config.float_precision, base=forward_policy, ) + + # State flow + if config.gflownet.state_flow is not None: + state_flow = instantiate( + config.gflownet.state_flow, + env=env, + device=config.device, + float_precision=config.float_precision, + base=forward_policy, + ) + else: + state_flow = None + + # GFlowNet Agent gflownet = instantiate( config.gflownet, device=config.device, float_precision=config.float_precision, env=env, - buffer=config.env.buffer, forward_policy=forward_policy, backward_policy=backward_policy, + state_flow=state_flow, + buffer=config.env.buffer, logger=logger, eval_config=config.eval, ) + return gflownet diff --git a/main.py b/main.py index e9d6401b3..fdfff3d43 100644 --- a/main.py +++ b/main.py @@ -29,6 +29,7 @@ def main(config): # Set other random seeds set_seeds(config.seed) + # Initialize GFlowNet from config # Logger logger = hydra.utils.instantiate(config.logger, config, _recursive_=False) # The proxy is required in the env for scoring @@ -93,11 +94,11 @@ def main(config): if config.n_samples > 0 and config.n_samples <= 1e5: batch, times = gflownet.sample_batch(n_forward=config.n_samples, train=False) x_sampled = batch.get_terminating_states(proxy=True) - energies = env.proxy(x_sampled) + energies = gflownet.env.proxy(x_sampled) x_sampled = batch.get_terminating_states() df = pd.DataFrame( { - "readable": [env.state2readable(x) for x in x_sampled], + "readable": [gflownet.env.state2readable(x) for x in x_sampled], "energies": energies.tolist(), } ) From db4f1dcf7615f192821f2fa46ffd688ec58463e4 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 20 Feb 2024 19:15:14 -0500 Subject: [PATCH 052/106] trailing breakpoint --- tests/gflownet/eval/test_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/gflownet/eval/test_base.py b/tests/gflownet/eval/test_base.py index a62c3162d..c3feed63c 100644 --- a/tests/gflownet/eval/test_base.py +++ b/tests/gflownet/eval/test_base.py @@ -46,7 +46,7 @@ def dummy_evaluator(config_for_tests): @pytest.fixture -def constant_evaluator(): +def constant_evaluator(): # faster fixture for state-less tests CONSTANT_EVALUATOR.config = OmegaConf.create( {"eval_config": {"metrics": "all"}, "logger": {}} ) @@ -234,7 +234,6 @@ def test__eval(gflownet_for_tests, parameterization): assert Path("./replay.pkl").exists() # results: {"metrics": dict[str, float], "figs": list[plt.Figure]} results = gflownet_for_tests.evaluator.eval() - breakpoint() for k, v in results["metrics"].items(): assert isinstance(k, str) From 829cf4ab842745150e54f225e28a677cfe318375 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 12:59:39 -0500 Subject: [PATCH 053/106] Update docstring --- gflownet/evaluator/base.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index c2bd681b7..ae7d44a38 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -636,10 +636,9 @@ def eval(self, metrics=None, **plot_kwargs): Returns ------- - list - List of computed metrics and figures: [l1, kl, jsd, corr_prob_traj_rewards, - var_logrewards_logp, nll_tt, mean_logprobs_std, mean_probs_std, - logprobs_std_nll_ratio, figs] (should be refactored to dict) #TODO fix docstring + dict + Computed dict of metrics and figures as + `{"metrics": {str: float}, "figs": {str: plt.Figure}}`. """ gfn = self.gfn_agent metrics = self.make_metrics(metrics) @@ -834,13 +833,23 @@ def eval_and_log_top_k(self, it): if __name__ == "__main__": - # dev test case, will move to tests + # Try using the GFlowNetEvaluator by running this script from the root: + # $ ipython + # In [1]: run gflownet/eval/base.py + from pathlib import Path - scratch = Path(os.environ["SCRATCH"]) - run_dirs = scratch / "crystals/logs/icml24/crystalgfn" - gfn_run_dir = run_dirs / "4074836/2024-01-27_20-54-55/5908fe41" + # Demo run: + # $ python main.py user=$USER \ + # +experiments=simple_tetris \ + # logger.do.online=False \ + # eval.checkpoints_period=100 + scratch = Path("/network/scratch/s/schmidtv") + run_dirs = scratch / "crystals/logs/" + gfn_run_dir = run_dirs / "2024-02-20_19-31-50/0df3449a" # simple_tetris run gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) - gfne.plot() - gfne.compute_metrics() + results = gfne.eval() + for name, metric in results["metrics"].items(): + print(f"{name:20}: {metric:.4f}") + print("Available figures in results['figs']:", ", ".join(results["figs"].keys())) From 4347939602b88948df6680b368db1d9f300e4bd2 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 13:04:08 -0500 Subject: [PATCH 054/106] remove unused --- gflownet/evaluator/base.py | 1 - gflownet/gflownet.py | 1 - 2 files changed, 2 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index ae7d44a38..cae5dd888 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -513,7 +513,6 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): def compute_density_metrics(self, x_tt, dict_tt, metrics=None): gfn = self.gfn_agent metrics = self.make_metrics(metrics) - reqs = self.make_requirements(metrics=metrics) # TODO-V: unused for now, TBD density_metrics = {} x_sampled = density_true = density_pred = None diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 9574becce..8c7f3dbad 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1103,7 +1103,6 @@ def train(self): # Log if self.logger.lightweight: all_losses = all_losses[-100:] - all_visited = states_term # TODO-V: unused else: all_visited.extend(states_term) # Progress bar From 25cf720dae4e1f47e588c7825e28affc6e11530d Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 14:19:08 -0500 Subject: [PATCH 055/106] improve example --- gflownet/evaluator/base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index cae5dd888..2e524718c 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -834,7 +834,7 @@ def eval_and_log_top_k(self, it): if __name__ == "__main__": # Try using the GFlowNetEvaluator by running this script from the root: # $ ipython - # In [1]: run gflownet/eval/base.py + # In [1]: run gflownet/evaluator/base.py from pathlib import Path @@ -845,10 +845,14 @@ def eval_and_log_top_k(self, it): # eval.checkpoints_period=100 scratch = Path("/network/scratch/s/schmidtv") run_dirs = scratch / "crystals/logs/" - gfn_run_dir = run_dirs / "2024-02-20_19-31-50/0df3449a" # simple_tetris run + gfn_run_dir = run_dirs / "2024-02-20_19-31-50" # simple_tetris run gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) results = gfne.eval() for name, metric in results["metrics"].items(): print(f"{name:20}: {metric:.4f}") - print("Available figures in results['figs']:", ", ".join(results["figs"].keys())) + print( + "Available figures in results['figs']:", + ", ".join([fname for fname, fig in results["figs"].items() if fig is not None]) + or "None", + ) From 9f50b45c646dc281c71a7db74d80368d9cb4ee1f Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 14:19:28 -0500 Subject: [PATCH 056/106] move `from_agent` and `from_dir` methods --- gflownet/evaluator/base.py | 134 ++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 2e524718c..95f6ae53c 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -93,6 +93,73 @@ def __init__(self, **kwargs): self.metrics = self.make_metrics(self.config.metrics) self.reqs = self.make_requirements() + @classmethod + def from_dir( + cls: "GFlowNetEvaluator", + path: Union[str, os.PathLike], + no_wandb: bool = True, + print_config: bool = False, + device: str = "cuda", + load_final_ckpt: bool = True, + ): + """ + Instantiate a GFlowNetEvaluator from a run directory. + + Parameters + ---------- + cls : GFlowNetEvaluator + Class to instantiate. + path : Union[str, os.PathLike] + Path to the run directory from which to load the GFlowNetAgent. + no_wandb : bool, optional + Prevent wandb initialization, by default True + print_config : bool, optional + Whether or not to print the resulting (loaded) config, by default False + device : str, optional + Device to use for the instantiated GFlowNetAgent, by default "cuda" + load_final_ckpt : bool, optional + Use the latest possible checkpoint available in the path, by default True + + Returns + ------- + GFlowNetEvaluator + Instance of GFlowNetEvaluator with the GFlowNetAgent loaded from the run. + """ + gfn_agent, _ = load_gflow_net_from_run_path( + path, + no_wandb=no_wandb, + print_config=print_config, + device=device, + load_final_ckpt=load_final_ckpt, + ) + return GFlowNetEvaluator.from_agent(gfn_agent) + + @classmethod + def from_agent(cls, gfn_agent): + """ + Instantiate a GFlowNetEvaluator from a GFlowNetAgent. + + Parameters + ---------- + cls : GFlowNetEvaluator + Evaluator class to instantiate. + gfn_agent : GFlowNetAgent + Instance of GFlowNetAgent to use for the GFlowNetEvaluator. + + Returns + ------- + GFlowNetEvaluator + Instance of GFlowNetEvaluator with the provided GFlowNetAgent. + """ + from gflownet.gflownet import GFlowNetAgent + + assert isinstance(gfn_agent, GFlowNetAgent), ( + "gfn_agent should be an instance of GFlowNetAgent, but is an instance of " + + f"{type(gfn_agent)}." + ) + + return GFlowNetEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) + def make_metrics(self, metrics=None): """ Parse metrics from a dict, list, a string or None. @@ -342,73 +409,6 @@ def should_checkpoint(self, step): else: return not step % self.config.checkpoints_period - @classmethod - def from_dir( - cls: "GFlowNetEvaluator", - path: Union[str, os.PathLike], - no_wandb: bool = True, - print_config: bool = False, - device: str = "cuda", - load_final_ckpt: bool = True, - ): - """ - Instantiate a GFlowNetEvaluator from a run directory. - - Parameters - ---------- - cls : GFlowNetEvaluator - Class to instantiate. - path : Union[str, os.PathLike] - Path to the run directory from which to load the GFlowNetAgent. - no_wandb : bool, optional - Prevent wandb initialization, by default True - print_config : bool, optional - Whether or not to print the resulting (loaded) config, by default False - device : str, optional - Device to use for the instantiated GFlowNetAgent, by default "cuda" - load_final_ckpt : bool, optional - Use the latest possible checkpoint available in the path, by default True - - Returns - ------- - GFlowNetEvaluator - Instance of GFlowNetEvaluator with the GFlowNetAgent loaded from the run. - """ - gfn_agent, _ = load_gflow_net_from_run_path( - path, - no_wandb=no_wandb, - print_config=print_config, - device=device, - load_final_ckpt=load_final_ckpt, - ) - return GFlowNetEvaluator.from_agent(gfn_agent) - - @classmethod - def from_agent(cls, gfn_agent): - """ - Instantiate a GFlowNetEvaluator from a GFlowNetAgent. - - Parameters - ---------- - cls : GFlowNetEvaluator - Evaluator class to instantiate. - gfn_agent : GFlowNetAgent - Instance of GFlowNetAgent to use for the GFlowNetEvaluator. - - Returns - ------- - GFlowNetEvaluator - Instance of GFlowNetEvaluator with the provided GFlowNetAgent. - """ - from gflownet.gflownet import GFlowNetAgent - - assert isinstance(gfn_agent, GFlowNetAgent), ( - "gfn_agent should be an instance of GFlowNetAgent, but is an instance of " - + f"{type(gfn_agent)}." - ) - - return GFlowNetEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) - def plot(self, x_sampled=None, kde_pred=None, kde_true=None, **plot_kwargs): """ Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which From 3c5186b664613c54a8c9c300efbf045a6641eeff Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 14:19:54 -0500 Subject: [PATCH 057/106] use `gflownet_from_config` in `load_gflow_net_from_run_path` --- gflownet/utils/common.py | 70 ++++++++++++---------------------------- 1 file changed, 21 insertions(+), 49 deletions(-) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 094d8aea2..531cae9c6 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -211,6 +211,8 @@ def find_latest_checkpoint(ckpt_dir, ckpt_name): If no final checkpoint is found and no other checkpoints are found according to the specified pattern: `{ckpt_name}*`. """ + if ckpt_name is None: + return None ckpt_name = Path(ckpt_name).stem final = list(ckpt_dir.glob(f"{ckpt_name}*final*")) if len(final) > 0: @@ -345,46 +347,7 @@ def load_gflow_net_from_run_path( # Disable wandb config.logger.do.online = False - # Logger - logger = instantiate(config.logger, config, _recursive_=False) - # The proxy is required in the env for scoring - proxy = instantiate( - config.proxy, - device=config.device, - float_precision=config.float_precision, - ) - # The proxy is passed to env and used for computing rewards - env = instantiate( - config.env, - proxy=proxy, - device=config.device, - float_precision=config.float_precision, - ) - forward_config = parse_policy_config(config, kind="forward") - backward_config = parse_policy_config(config, kind="backward") - forward_policy = instantiate( - forward_config, - env=env, - device=config.device, - float_precision=config.float_precision, - ) - backward_policy = instantiate( - backward_config, - env=env, - device=config.device, - float_precision=config.float_precision, - base=forward_policy, - ) - gflownet = instantiate( - config.gflownet, - device=config.device, - float_precision=config.float_precision, - env=env, - buffer=config.env.buffer, - forward_policy=forward_policy, - backward_policy=backward_policy, - logger=logger, - ) + gflownet = gflownet_from_config(config) if not load_final_ckpt: return gflownet, config @@ -393,18 +356,27 @@ def load_gflow_net_from_run_path( # ----- Load final models ----- # ------------------------------- - ckpt = [f for f in run_path.rglob(config.logger.logdir.ckpts) if f.is_dir()][0] - forward_final = find_latest_checkpoint(ckpt, config.policy.forward.checkpoint) - gflownet.forward_policy.model.load_state_dict( - torch.load(forward_final, map_location=set_device(device)) - ) - try: - backward_final = find_latest_checkpoint(ckpt, config.policy.backward.checkpoint) + ckpt_dir = [f for f in run_path.rglob(config.logger.logdir.ckpts) if f.is_dir()][0] + + forward_final = find_latest_checkpoint(ckpt_dir, config.policy.forward.checkpoint) + if forward_final is None: + print("Warning: no forward policy checkpoint found") + else: + gflownet.forward_policy.model.load_state_dict( + torch.load(forward_final, map_location=set_device(device)) + ) + + backward_final = find_latest_checkpoint(ckpt_dir, config.policy.backward.checkpoint) + if backward_final is None: + print("Warning: no backward policy checkpoint found") + else: gflownet.backward_policy.model.load_state_dict( torch.load(backward_final, map_location=set_device(device)) ) - except ValueError: - print("No backward policy found") + + if forward_final is None and backward_final is None: + print("Warning: no checkpoints found in", str(ckpt_dir)) + return gflownet, config From df6d9af40964d88336827120acb1651f9c8e149f Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 14:21:02 -0500 Subject: [PATCH 058/106] `empty_ok=False` arg --- gflownet/utils/common.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 531cae9c6..c3e34d6aa 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -310,6 +310,7 @@ def load_gflow_net_from_run_path( print_config=False, device="cuda", load_final_ckpt=True, + empty_ok=False, ): """ Load GFlowNet from a run path (directory with a `.hydra` directory inside). @@ -326,11 +327,18 @@ def load_gflow_net_from_run_path( Device to which the models should be moved, by default "cuda". load_final_ckpt : bool, optional Whether to load the final models, by default True. + empty_ok : bool, optional + Whether to allow the checkpoints directory to be empty, by default False. Returns ------- Tuple[GFN, DictConfig] Loaded GFlowNet and the loaded config. + + Raises + ------ + ValueError + If no checkpoints are found in the directory. """ run_path = resolve_path(run_path) hydra_dir = run_path / ".hydra" @@ -375,6 +383,8 @@ def load_gflow_net_from_run_path( ) if forward_final is None and backward_final is None: + if not empty_ok: + raise ValueError("No checkpoints found in", str(ckpt_dir)) print("Warning: no checkpoints found in", str(ckpt_dir)) return gflownet, config From 626d34c07a4f7d2e4d4482b4a99134fb17bba737 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 14:24:20 -0500 Subject: [PATCH 059/106] clean up example --- gflownet/evaluator/base.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 95f6ae53c..68f95f9e5 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -835,22 +835,21 @@ def eval_and_log_top_k(self, it): # Try using the GFlowNetEvaluator by running this script from the root: # $ ipython # In [1]: run gflownet/evaluator/base.py + # + # Note: this will not work on previous checkpoints whose config does not contain an + # `eval` entry, you have to run one. Add `eval.checkpoint_period=10` to quickly + # have a checkpoint to test. from pathlib import Path - # Demo run: - # $ python main.py user=$USER \ - # +experiments=simple_tetris \ - # logger.do.online=False \ - # eval.checkpoints_period=100 - scratch = Path("/network/scratch/s/schmidtv") - run_dirs = scratch / "crystals/logs/" - gfn_run_dir = run_dirs / "2024-02-20_19-31-50" # simple_tetris run + gfn_run_dir = "PUT_YOUR_RUN_DIR_HERE" # a run dir contains a .hydra folder gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) results = gfne.eval() + for name, metric in results["metrics"].items(): print(f"{name:20}: {metric:.4f}") + print( "Available figures in results['figs']:", ", ".join([fname for fname, fig in results["figs"].items() if fig is not None]) From a7018bb8fb61f3856a8e5797ea3b65b32bfec1ee Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Wed, 21 Feb 2024 14:40:51 -0500 Subject: [PATCH 060/106] document constants --- gflownet/evaluator/base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 68f95f9e5..e3afedd23 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -57,7 +57,21 @@ "requirements": ["log_probs"], }, } +""" +All metrics that can be computed by the GFlowNetEvaluator. Structured as a dict with the +metric names as keys and the metric display names and requirements as values. + +Requirements are used to decide which kind of data / samples is required to compute the +metric. + +Display names are used to log the metrics and to display them in the console. +""" + ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) +""" +Union of all requirements of all metrics in `METRICS`. Computed from +:py:const:`METRICS`. +""" class GFlowNetEvaluator: From b6832dad49e9eec7e5c16847559b54dadb6f5ee5 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 14:22:02 -0500 Subject: [PATCH 061/106] eval top k uses dict data structure --- gflownet/evaluator/base.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index e3afedd23..c9d00c748 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -714,9 +714,10 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): Returns ------- - tuple[dict, dict[str, plt.Figure], dict] - Computed dict of metrics, and figures (as {str: plt.Figure}), and optionally - (only once) summary metrics. + dict + Computed dict of metrics, and figures, and optionally (only once) summary + metrics. Schema: ``{"metrics": {str: float}, "figs": {str: plt.Figure}, + "summary": {str: float}}``. """ # only do random top k plots & metrics once do_random = it // self.logger.test.top_k_period == 1 @@ -800,7 +801,11 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): figs = {f: n for f, n in zip(figs, fig_names)} - return metrics, figs, summary + return { + "metrics": metrics, + "figs": figs, + "summary": summary, + } def eval_and_log(self, it, metrics=None): """ @@ -839,10 +844,12 @@ def eval_and_log_top_k(self, it): Current iteration step, by default None. """ - metrics, figs, summary = self.eval_top_k(it) - self.logger.log_plots(figs, it, use_context=self.use_context) - self.logger.log_metrics(metrics, use_context=self.use_context, step=it) - self.logger.log_summary(summary) + results = self.eval_top_k(it) + self.logger.log_plots(results["figs"], it, use_context=self.use_context) + self.logger.log_metrics( + results["metrics"], use_context=self.use_context, step=it + ) + self.logger.log_summary(results["summary"]) if __name__ == "__main__": From b6527f8e15eab320cb695c2e34d84948eac1dc11 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 14:22:44 -0500 Subject: [PATCH 062/106] improve docstrings --- gflownet/evaluator/base.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index c9d00c748..0e48ab867 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -434,11 +434,12 @@ def plot(self, x_sampled=None, kde_pred=None, kde_true=None, **plot_kwargs): Extend this method to add more plots: - ```python - def plot(self, x_sampled, kde_pred, kde_true, **plot_kwargs): - figs = super().plot(x_sampled, kde_pred, kde_true, **plot_kwargs) figs["My - custom plot"] = my_custom_plot_function(x_sampled, kde_pred) return figs - ``` + .. code-block:: python + + def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): + figs = super().plot(x_sampled, kde_pred, kde_true, plot_kwargs) + figs["My custom plot"] = my_custom_plot_function(x_sampled, kde_pred) + return figs Parameters ---------- @@ -450,6 +451,8 @@ def plot(self, x_sampled, kde_pred, kde_true, **plot_kwargs): True KDE. plot_kwargs : dict Additional keyword arguments to pass to the plotting methods. + kwargs : dict + Catch-all for additional arguments. Returns ------- @@ -631,13 +634,13 @@ def eval(self, metrics=None, **plot_kwargs): Extand in subclasses to add more metrics and plots: - ```python - def eval(self, metrics=None, **plot_kwargs): - result = super().eval(metrics=metrics, **plot_kwargs) - result["metrics"]["my_custom_metric"] = my_custom_metric_function() - result["figs"]["My custom plot"] = my_custom_plot_function() - return result - ``` + .. code-block:: python + + def eval(self, metrics=None, **plot_kwargs): + result = super().eval(metrics=metrics, **plot_kwargs) + result["metrics"]["my_custom_metric"] = my_custom_metric_function() + result["figs"]["My custom plot"] = my_custom_plot_function() + return result Parameters ---------- From 9e09b63b6cd996ed2a3ad4e306bd5e01043e14a6 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 14:23:05 -0500 Subject: [PATCH 063/106] add `update_all_metrics_and_requirements` --- gflownet/evaluator/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 0e48ab867..793036da1 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -104,9 +104,19 @@ def __init__(self, **kwargs): self.reqs = set() self.metrics = self.reqs = _sentinel + + self.update_all_metrics_and_requirements() + self.metrics = self.make_metrics(self.config.metrics) self.reqs = self.make_requirements() + def update_all_metrics_and_requirements(self): + """ + Method to be implemented by subclasses to update the global dict of metrics and + requirements. + """ + pass + @classmethod def from_dir( cls: "GFlowNetEvaluator", From 47c699212ad99a5162ee6197cc61ad5476e949c0 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 14:23:29 -0500 Subject: [PATCH 064/106] have dedicated `plot_kwargs` --- gflownet/evaluator/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 793036da1..8f5fc510b 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -433,7 +433,9 @@ def should_checkpoint(self, step): else: return not step % self.config.checkpoints_period - def plot(self, x_sampled=None, kde_pred=None, kde_true=None, **plot_kwargs): + def plot( + self, x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs + ): """ Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which will be logged. From 4985d9f48ef54456d644fd0261cf382c23dec1b5 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 14:24:01 -0500 Subject: [PATCH 065/106] standardize `{"metrics": {}, "data": {}}` return pattern --- gflownet/evaluator/base.py | 66 +++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 8f5fc510b..57d659217 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -537,13 +537,17 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): -logprobs_std.mean() / logprobs_x_tt.mean() ).item() - return lp_metrics + return { + "metrics": lp_metrics, + } def compute_density_metrics(self, x_tt, dict_tt, metrics=None): gfn = self.gfn_agent metrics = self.make_metrics(metrics) density_metrics = {} + density_data = {} + x_sampled = density_true = density_pred = None if gfn.buffer.test_type is not None and gfn.buffer.test_type == "all": @@ -609,15 +613,18 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): density_true = np.exp(log_density_true) density_pred = np.exp(log_density_pred) - density_metrics["kde_pred"] = kde_pred - density_metrics["kde_true"] = kde_true + density_data["kde_pred"] = kde_pred + density_data["kde_true"] = kde_true else: density_metrics["l1"] = gfn.l1 density_metrics["kl"] = gfn.kl density_metrics["jsd"] = gfn.jsd - density_metrics["x_sampled"] = x_sampled - return density_metrics + density_data["x_sampled"] = x_sampled + return { + "metrics": density_metrics, + "data": density_data, + } # L1 error density_metrics["l1"] = np.abs(density_pred - density_true).mean() @@ -634,8 +641,12 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): density_pred * (log_density_pred - log_mean_dens) ) - density_metrics["x_sampled"] = x_sampled - return density_metrics + density_data["x_sampled"] = x_sampled + + return { + "metrics": density_metrics, + "data": density_data, + } def eval(self, metrics=None, **plot_kwargs): """ @@ -672,18 +683,16 @@ def eval(self, metrics=None, **plot_kwargs): metrics = self.make_metrics(metrics) reqs = self.make_requirements(metrics=metrics) - all_metrics = {} - x_sampled = kde_pred = kde_true = None - figs = {} - if gfn.buffer.test_pkl is None: - result = { + return { "metrics": { k: getattr(gfn, k) if hasattr(gfn, k) else None for k in metrics }, - "figs": figs, + "data": {}, } - return result + + all_data = {} + all_metrics = {} with open(gfn.buffer.test_pkl, "rb") as f: dict_tt = pickle.load(f) @@ -693,23 +702,20 @@ def eval(self, metrics=None, **plot_kwargs): # likelihood of the data according the the GFlowNet policy; and NLL. # TODO: organise code for better efficiency and readability if "log_probs" in reqs: - lp_metrics = self.compute_log_prob_metrics(x_tt, metrics=metrics) - all_metrics.update(lp_metrics) + lp_results = self.compute_log_prob_metrics(x_tt, metrics=metrics) + all_metrics.update(lp_results.get("metrics", {})) + all_data.update(lp_results.get("data", {})) if "density" in reqs: - density_metrics = self.compute_density_metrics( + density_results = self.compute_density_metrics( x_tt, dict_tt, metrics=metrics ) - x_sampled = density_metrics.pop("x_sampled", x_sampled) - kde_pred = density_metrics.pop("kde_pred", kde_pred) - kde_true = density_metrics.pop("kde_true", kde_true) - all_metrics.update(density_metrics) - - figs = self.plot(x_sampled=x_sampled, kde_pred=kde_pred, kde_true=kde_true) + all_metrics.update(density_results.get("metrics", {})) + all_data.update(density_results.get("data", {})) return { "metrics": all_metrics, - "figs": figs, + "data": all_data, } @torch.no_grad() @@ -837,16 +843,18 @@ def eval_and_log(self, it, metrics=None): List of metrics to compute, by default the evaluator's `metrics` attribute. """ gfn = self.gfn_agent - result = self.eval(metrics=metrics) - for m, v in result["metrics"].items(): + results = self.eval(metrics=metrics) + for m, v in results["metrics"].items(): setattr(gfn, m, v) mertics_to_log = { - METRICS[k]["display_name"]: v for k, v in result["metrics"].items() + METRICS[k]["display_name"]: v for k, v in results["metrics"].items() } + figs = self.plot(**results["data"]) + self.logger.log_metrics(mertics_to_log, it, gfn.use_context) - self.logger.log_plots(result["figs"], it, use_context=gfn.use_context) + self.logger.log_plots(figs, it, use_context=gfn.use_context) def eval_and_log_top_k(self, it): """ @@ -876,8 +884,6 @@ def eval_and_log_top_k(self, it): # `eval` entry, you have to run one. Add `eval.checkpoint_period=10` to quickly # have a checkpoint to test. - from pathlib import Path - gfn_run_dir = "PUT_YOUR_RUN_DIR_HERE" # a run dir contains a .hydra folder gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) From 054aa6131b99ce9de6922d8b7425dbc94964794d Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 14:24:23 -0500 Subject: [PATCH 066/106] work on docstrings example --- gflownet/evaluator/base.py | 125 +++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 57d659217..0dc5a0a2b 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -1,3 +1,128 @@ +""" +Base evaluator class for GFlowNetAgent. + +In charge of evaluating the GFlowNetAgent, computing metrics plotting figures and +optionally logging results using the GFlowNetAgent's logger. + +Only the :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.from_dir` and +:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.from_agent` class methods should be +used to instantiate this class. + +Create a new evaluator by subclassing this class and extending the :py:meth:`eval` +method to add more metrics and plots. + +Typical call stack: + +1. :py:meth:`gflownet.gflownet.GFlowNetAgent.train` calls the evaluator's +2. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.should_eval`. + If it returns ``True`` then :py:meth:`~gflownet.gflownet.GFlowNetAgent.train` calls +3. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` which itself calls +4. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` as + ``results = self.eval(metrics=None)`` and then ``figs = self.plot(**results["data"])`` +5. finally, :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` logs the + results using the GFlowNetAgent's logger as + ``self.logger.log_metrics(results["metrics"])`` and ``self.logger.log_plots(figs)``. + +Example +------- + +.. code-block:: python + + # gflownet/evaluator/my_evaluator.py + + from gflownet.evaluator.base import GFlowNetEvaluator, METRICS, ALL_REQS + + class MyEvaluator(GFlowNetEvaluator): + + def update_all_metrics_and_requirements(self): + global METRICS, ALL_REQS + + METRICS["my_custom_metric"] = { + "display_name": "My custom metric", + "requirements": ["density", "new_req"], + } + + ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) + + + def my_custom_metric(self, some, arguments): + intermediate = some + arguments + + return { + "metrics": { + "my_custom_metric": intermediate ** (-0.5) + }, + "data": { + "some_other": some ** 2, + "arguments": arguments, + "intermediate": intermediate, + } + } + ... + + def my_custom_plot( + self, some_other=None, arguments=None, intermediate=None, **kwargs + ): + # whatever gets to **kwargs will be ignored, this is used to handle + # methods with varying signatures. + figs = {} + if some_other is not None: + f = plt.figure() + # some plotting procedure for some_other + figs["My Title"] = f + + if arguments is not None: + f = plt.figure() + # some other plotting procedure with both + figs["My Other Title"] = f + elif arguments is not None: + f = plt.figure() + # some other plotting procedure with arguments + figs["My 3rd Title"] = f + + if intermediate is not None: + f = plt.figure() + # some other plotting procedure with intermediate + figs["My 4th Title"] = f + + return figs + + def plot(self, **kwargs): + figs = super().plot(**kwargs) + figs.update(self.my_custom_plot(**kwargs)) + + return figs + + def eval(self, metrics=None, **plot_kwargs): + gfn = self.gfn_agent + metrics = self.make_metrics(metrics) + reqs = self.make_requirements(metrics=metrics) + + results = super().eval(metrics=metrics, **plot_kwargs) + + if "new_req" in reqs: + my_results = self.my_custom_metric(some, arguments) + results["metrics"].update(my_results.get("metrics", {})) + results["data"].update(my_results.get("data", {})) + + return results + +In the previous example, the `update_all_metrics_and_requirements` method is used to +update the global `METRICS` and `ALL_REQS` variables. It will be called when the +`MyEvaluator` class is instantiated, in the init of `BaseEvaluator`. + +By defining a new requirement, you ensure that the new metrics and plots will only be +computed if user asks for a metric that requires such computations. + +By default, the train loop will call +:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` which itself calls +:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` so if you override ``eval()`` +as above, the new metrics and plots will be computed and logged. + +Similarly, `eval_and_log` will compute the ``dict`` of figures as +``fig_dict = self.plot(**results["data"])`` where ``results`` is the output of ``eval``. +""" + import copy import os import pickle From 0ac39c836c3e0aadcdae4a05a93259fecdf5201f Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 14:53:30 -0500 Subject: [PATCH 067/106] towards abstract / base pattern --- gflownet/evaluator/abstract.py | 622 +++++++++++++++++++++++++++++++++ gflownet/evaluator/base.py | 577 +----------------------------- 2 files changed, 629 insertions(+), 570 deletions(-) create mode 100644 gflownet/evaluator/abstract.py diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py new file mode 100644 index 000000000..77a435305 --- /dev/null +++ b/gflownet/evaluator/abstract.py @@ -0,0 +1,622 @@ +""" +Abstract evaluator class for GFlowNetAgent. + +Should not be used directly, but subclassed to implement specific evaluators for +different tasks and environments. + +See :py:class:`~gflownet.evaluator.base.GFlowNetEvaluator` for a default, +concrete implementation of this abstract class. + +This class handles some logic that will be the same for all evaluators. +The only requirements for a subclass are to implement the `plot` and `eval` methods +which will be called by the +:py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.eval_and_log` method. + +.. code-block:: python + + def eval_and_log(self, it, metrics=None): + gfn = self.gfn_agent + results = self.eval(metrics=metrics) + for m, v in results["metrics"].items(): + setattr(gfn, m, v) + + mertics_to_log = { + METRICS[k]["display_name"]: v for k, v in results["metrics"].items() + } + + figs = self.plot(**results["data"]) + + self.logger.log_metrics(mertics_to_log, it, gfn.use_context) + self.logger.log_plots(figs, it, use_context=gfn.use_context) +""" + +import copy +import os +import time +from abc import ABCMeta, abstractmethod +from typing import Union + +import torch + +from gflownet.utils.batch import Batch +from gflownet.utils.common import batch_with_rest, load_gflow_net_from_run_path + +_sentinel = object() + +METRICS = { + "l1": { + "display_name": "L1 error", + "requirements": ["density"], + }, + "kl": { + "display_name": "KL Div.", + "requirements": ["density"], + }, + "jsd": { + "display_name": "Jensen Shannon Div.", + "requirements": ["density"], + }, + "corr_prob_traj_rewards": { + "display_name": "Corr. (test probs., rewards)", + "requirements": ["log_probs", "reward_batch"], + }, + "var_logrewards_logp": { + "display_name": "Var(logR - logp) test", + "requirements": ["log_probs", "reward_batch"], + }, + "nll_tt": { + "display_name": "NLL of test data", + "requirements": ["log_probs"], + }, + "mean_logprobs_std": { + "display_name": "Mean BS Std(logp)", + "requirements": ["log_probs"], + }, + "mean_probs_std": { + "display_name": "Mean BS Std(p)", + "requirements": ["log_probs"], + }, + "logprobs_std_nll_ratio": { + "display_name": "BS Std(logp) / NLL", + "requirements": ["log_probs"], + }, +} +""" +All metrics that can be computed by a GFlowNetEvaluator. Structured as a dict with the +metric names as keys and the metric display names and requirements as values. + +Requirements are used to decide which kind of data / samples is required to compute the +metric. + +Display names are used to log the metrics and to display them in the console. +""" + +ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) +""" +Union of all requirements of all metrics in `METRICS`. Computed from +:py:const:`METRICS`. +""" + + +class GFlowNetAbstractEvaluator(metaclass=ABCMeta): + def __init__(self, **kwargs): + """ + Base evaluator class for GFlowNetAgent. + + In charge of evaluating the GFlowNetAgent, computing metrics plotting figures + and optionally logging results using the GFlowNetAgent's logger. + + Only the `from_dir` and `from_agent` class methods should be used to instantiate + this class. + + Raises + ------ + NotImplementedError + If the `sentinel` keyword argument is not `_sentinel`, which is used to + prevent instantiation of the base class without using the `from_dir` or + `from_agent` class methods. + + """ + if kwargs.get("sentinel") is not _sentinel: + raise NotImplementedError( + "Base evaluator class should not be instantiated. Use " + + "GFlowNetEvaluator.from_dir or GFlowNetEvaluator.from_agent methods." + ) + self.gfn_agent = kwargs.get("gfn_agent") + self.config = self.gfn_agent.eval_config + self.logger = self.gfn_agent.logger + self.reqs = set() + + self.metrics = self.reqs = _sentinel + + self.update_all_metrics_and_requirements() + + self.metrics = self.make_metrics(self.config.metrics) + self.reqs = self.make_requirements() + + def update_all_metrics_and_requirements(self): + """ + Method to be implemented by subclasses to update the global dict of metrics and + requirements. + """ + pass + + @classmethod + def from_dir( + cls: "GFlowNetAbstractEvaluator", + path: Union[str, os.PathLike], + no_wandb: bool = True, + print_config: bool = False, + device: str = "cuda", + load_final_ckpt: bool = True, + ): + """ + Instantiate a GFlowNetEvaluator from a run directory. + + Parameters + ---------- + cls : GFlowNetEvaluator + Class to instantiate. + path : Union[str, os.PathLike] + Path to the run directory from which to load the GFlowNetAgent. + no_wandb : bool, optional + Prevent wandb initialization, by default True + print_config : bool, optional + Whether or not to print the resulting (loaded) config, by default False + device : str, optional + Device to use for the instantiated GFlowNetAgent, by default "cuda" + load_final_ckpt : bool, optional + Use the latest possible checkpoint available in the path, by default True + + Returns + ------- + GFlowNetEvaluator + Instance of GFlowNetEvaluator with the GFlowNetAgent loaded from the run. + """ + gfn_agent, _ = load_gflow_net_from_run_path( + path, + no_wandb=no_wandb, + print_config=print_config, + device=device, + load_final_ckpt=load_final_ckpt, + ) + return GFlowNetAbstractEvaluator.from_agent(gfn_agent) + + @classmethod + def from_agent(cls, gfn_agent): + """ + Instantiate a GFlowNetEvaluator from a GFlowNetAgent. + + Parameters + ---------- + cls : GFlowNetEvaluator + Evaluator class to instantiate. + gfn_agent : GFlowNetAgent + Instance of GFlowNetAgent to use for the GFlowNetEvaluator. + + Returns + ------- + GFlowNetEvaluator + Instance of GFlowNetEvaluator with the provided GFlowNetAgent. + """ + from gflownet.gflownet import GFlowNetAgent + + assert isinstance(gfn_agent, GFlowNetAgent), ( + "gfn_agent should be an instance of GFlowNetAgent, but is an instance of " + + f"{type(gfn_agent)}." + ) + + return GFlowNetAbstractEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) + + def make_metrics(self, metrics=None): + """ + Parse metrics from a dict, list, a string or None. + + - If `None`, all metrics are selected. + - If a string, it can be a comma-separated list of metric names, with or without + spaces. + - If a list, it should be a list of metric names (keys of `METRICS`). + - If a dict, its keys should be metric names and its values will be ignored: + they will be assigned from `METRICS`. + + All metrics must be in `METRICS`. + + Parameters + ---------- + metrics : Union[str, List[str]], optional + Metrics to compute when running the `evaluator.eval()` function. Defaults to + None, i.e. all metrics in `METRICS` are computed. + + Returns + ------- + dict + Dictionary of metrics to compute, with the metric names as keys and the + metric display names and requirements as values. + + Raises + ------ + ValueError + If a metric name is not in `METRICS`. + """ + if metrics is None: + assert self.metrics is not _sentinel, ( + "Error setting self.metrics. This is likely due to the `metrics:`" + + " entry missing from your eval config. Set it to 'all' to compute all" + + " metrics or to a comma-separated list of metric names (eg 'l1, kl')." + ) + return self.metrics + + if not isinstance(metrics, (str, list, dict)): + raise ValueError( + "metrics should be None, a string, a list or a dict," + + f" but is {type(metrics)}." + ) + + if metrics == "all": + metrics = METRICS.keys() + + if isinstance(metrics, str): + if metrics == "": + raise ValueError( + "`metrics` should not be an empty string. " + + "Set to 'all' or a list of metric names or None (null in YAML)." + ) + if "," in metrics: + metrics = metrics.split(",") + else: + metrics = [metrics] + + if isinstance(metrics, dict): + metrics = metrics.keys() + + metrics = [m.strip() for m in metrics] + + for m in metrics: + if m not in METRICS: + raise ValueError(f"Unknown metric name: {m}") + + return {m: METRICS[m] for m in metrics} + + def make_requirements(self, reqs=None, metrics=None): + """ + Make requirements for the metrics to compute. + + 1. If `metrics` is provided, they must be as a dict of metrics. The requirements + are computed from the `requirements` attribute of the metrics. + + 2. Otherwise, the requirements are computed from the `reqs` argument: + - If `reqs` is `"all"`, all requirements of all metrics are computed. + - If `reqs` is `None`, the evaluator's `self.reqs` attribute is used. + - If `reqs` is a list, it is used as the requirements. + + Parameters + ---------- + reqs : Union[str, List[str]], optional + The metrics requirements. Either `"all"`, a list of requirements or `None` + to use the evaluator's `self.reqs` attribute. By default None + metrics : Union[str, List[str], dict], optional + The metrics to compute requirements for. If not a dict, will be passed to + `make_metrics`. By default None. + + Returns + ------- + set[str] + The set of requirements for the metrics. + """ + + if metrics is not None: + if not isinstance(metrics, dict): + metrics = self.make_metrics(metrics) + for m in metrics: + if m not in METRICS: + raise ValueError(f"Unknown metric name: {m}") + return set([r for m in metrics.values() for r in m["requirements"]]) + + if isinstance(reqs, str): + if reqs == "all": + reqs = ALL_REQS.copy() + else: + raise ValueError( + "reqs should be 'all', a list of requirements or None, but is " + + f"{reqs}." + ) + if reqs is None: + if self.reqs is _sentinel: + if not isinstance(self.metrics, dict): + raise ValueError( + "Cannot compute requirements from `None` without the `metrics`" + + " argument or the `self.metrics` attribute set to a dict" + + " of metrics." + ) + self.reqs = set( + [r for m in self.metrics.values() for r in m["requirements"]] + ) + reqs = self.reqs + if isinstance(reqs, list): + reqs = set(reqs) + + assert isinstance( + reqs, set + ), f"reqs should be None, 'all', a set or a list, but is {type(reqs)}" + + assert all([isinstance(r, str) for r in reqs]), ( + "All elements of reqs should be strings, but are " + + f"{[type(r) for r in reqs]}" + ) + + for r in reqs: + if r not in ALL_REQS: + raise ValueError(f"Unknown requirement: {r}") + + return reqs + + def should_log_train(self, step): + """ + Check if training logs should be done at the current step. The decision is based + on the `self.config.train.period` attribute. + + Set `self.config.train.period` to `None` or a negative value to disable + training. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if train logging should be done at the current step, False otherwise. + """ + if self.config.train_log_period is None or self.config.train_log_period <= 0: + return False + else: + return step % self.config.train_log_period == 0 + + def should_eval(self, step): + """ + Check if testing should be done at the current step. The decision is based on + the `self.config.test.period` attribute. + + Set `self.config.test.first_it` to `True` if testing should be done at the first + iteration step. Otherwise, testing will be done aftter `self.config.test.period` + steps. + + Set `self.config.test.period` to `None` or a negative value to disable testing. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if testing should be done at the current step, False otherwise. + """ + if self.config.period is None or self.config.period <= 0: + return False + elif step == 1 and self.config.first_it: + return True + else: + return step % self.config.period == 0 + + def should_eval_top_k(self, step): + """ + Check if top k plots and metrics should be done at the current step. The + decision is based on the `self.config.test.top_k` and + `self.config.test.top_k_period` attributes. + + Set `self.config.test.top_k` to `None` or a negative value to disable top k + plots and metrics. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if top k plots and metrics should be done at the current step, False + """ + if self.config.top_k is None or self.config.top_k <= 0: + return False + + if self.config.top_k_period is None or self.config.top_k_period <= 0: + return False + + if step == 1 and self.config.first_it: + return True + + return step % self.config.top_k_period == 0 + + def should_checkpoint(self, step): + """ + Check if checkpoints should be done at the current step. The decision is based + on the `self.checkpoints.period` attribute. + + Set `self.checkpoints.period` to `None` or a negative value to disable + checkpoints. + + Parameters + ---------- + step : int + Current iteration step. + + Returns + ------- + bool + True if checkpoints should be done at the current step, False otherwise. + """ + if ( + self.config.checkpoints_period is None + or self.config.checkpoints_period <= 0 + ): + return False + else: + return not step % self.config.checkpoints_period + + @abstractmethod + def plot(self, **kwargs): + pass + + @abstractmethod + def eval(self, metrics=None, **plot_kwargs): + pass + + def eval_and_log(self, it, metrics=None): + """ + Evaluate the GFlowNetAgent and log the results with its logger. + + Will call `self.eval()` and log the results using the GFlowNetAgent's logger + `log_metrics()` and `log_plots()` methods. + + Parameters + ---------- + it : int + Current iteration step. + metrics : Union[str, List[str]], optional + List of metrics to compute, by default the evaluator's `metrics` attribute. + """ + gfn = self.gfn_agent + results = self.eval(metrics=metrics) + for m, v in results["metrics"].items(): + setattr(gfn, m, v) + + mertics_to_log = { + METRICS[k]["display_name"]: v for k, v in results["metrics"].items() + } + + figs = self.plot(**results["data"]) + + self.logger.log_metrics(mertics_to_log, it, gfn.use_context) + self.logger.log_plots(figs, it, use_context=gfn.use_context) + + @torch.no_grad() + def eval_top_k(self, it, gfn_states=None, random_states=None): + """ + Sample from the current GFN and compute metrics and plots for the top k states + according to both the energy and the reward. + + Parameters + ---------- + it : int + current iteration + gfn_states : list, optional + Already sampled gfn states. Defaults to None. + random_states : list, optional + Already sampled random states. Defaults to None. + + Returns + ------- + dict + Computed dict of metrics, and figures, and optionally (only once) summary + metrics. Schema: ``{"metrics": {str: float}, "figs": {str: plt.Figure}, + "summary": {str: float}}``. + """ + # only do random top k plots & metrics once + do_random = it // self.logger.test.top_k_period == 1 + duration = None + summary = {} + prob = copy.deepcopy(self.random_action_prob) + gfn = self.gfn_agent + print() + if not gfn_states: + # sample states from the current gfn + batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) + gfn.random_action_prob = 0 + t = time.time() + print("Sampling from GFN...", end="\r") + for b in batch_with_rest(0, gfn.logger.test.n_top_k, gfn.batch_size_total): + sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) + batch.merge(sub_batch) + duration = time.time() - t + gfn_states = batch.get_terminating_states() + + # compute metrics and get plots + print("[eval_top_k] Making GFN plots...", end="\r") + metrics, figs, fig_names = gfn.env.top_k_metrics_and_plots( + gfn_states, gfn.logger.test.top_k, name="gflownet", step=it + ) + if duration: + metrics["gflownet top k sampling duration"] = duration + + if do_random: + # sample random states from uniform actions + if not random_states: + batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) + gfn.random_action_prob = 1.0 + print("[eval_top_k] Sampling at random...", end="\r") + for b in batch_with_rest( + 0, gfn.logger.test.n_top_k, gfn.batch_size_total + ): + sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) + batch.merge(sub_batch) + # compute metrics and get plots + random_states = batch.get_terminating_states() + print("[eval_top_k] Making Random plots...", end="\r") + ( + random_metrics, + random_figs, + random_fig_names, + ) = gfn.env.top_k_metrics_and_plots( + random_states, gfn.logger.test.top_k, name="random", step=None + ) + # add to current metrics and plots + summary.update(random_metrics) + figs += random_figs + fig_names += random_fig_names + # compute training data metrics and get plots + print("[eval_top_k] Making train plots...", end="\r") + ( + train_metrics, + train_figs, + train_fig_names, + ) = gfn.env.top_k_metrics_and_plots( + None, gfn.logger.test.top_k, name="train", step=None + ) + # add to current metrics and plots + summary.update(train_metrics) + figs += train_figs + fig_names += train_fig_names + + gfn.random_action_prob = prob + + print(" " * 100, end="\r") + print("eval_top_k metrics:") + max_k = max([len(k) for k in (list(metrics.keys()) + list(summary.keys()))]) + 1 + print( + " • " + + "\n • ".join( + f"{k:{max_k}}: {v:.4f}" + for k, v in (list(metrics.items()) + list(summary.items())) + ) + ) + print() + + figs = {f: n for f, n in zip(figs, fig_names)} + + return { + "metrics": metrics, + "figs": figs, + "summary": summary, + } + + def eval_and_log_top_k(self, it): + """ + Evaluate the GFlowNetAgent's top k samples performance and log the results with + its logger. + + Parameters + ---------- + it : int + Current iteration step, by default None. + """ + + results = self.eval_top_k(it) + self.logger.log_plots(results["figs"], it, use_context=self.use_context) + self.logger.log_metrics( + results["metrics"], use_context=self.use_context, step=it + ) + self.logger.log_summary(results["summary"]) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 0dc5a0a2b..d4e024546 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -134,6 +134,12 @@ def eval(self, metrics=None, **plot_kwargs): import torch from scipy.special import logsumexp +from gflownet.evaluator.abstract import ( + ALL_REQS, + METRICS, + GFlowNetAbstractEvaluator, + _sentinel, +) from gflownet.utils.batch import Batch from gflownet.utils.common import ( batch_with_rest, @@ -142,421 +148,8 @@ def eval(self, metrics=None, **plot_kwargs): torch2np, ) -_sentinel = object() - -METRICS = { - "l1": { - "display_name": "L1 error", - "requirements": ["density"], - }, - "kl": { - "display_name": "KL Div.", - "requirements": ["density"], - }, - "jsd": { - "display_name": "Jensen Shannon Div.", - "requirements": ["density"], - }, - "corr_prob_traj_rewards": { - "display_name": "Corr. (test probs., rewards)", - "requirements": ["log_probs", "reward_batch"], - }, - "var_logrewards_logp": { - "display_name": "Var(logR - logp) test", - "requirements": ["log_probs", "reward_batch"], - }, - "nll_tt": { - "display_name": "NLL of test data", - "requirements": ["log_probs"], - }, - "mean_logprobs_std": { - "display_name": "Mean BS Std(logp)", - "requirements": ["log_probs"], - }, - "mean_probs_std": { - "display_name": "Mean BS Std(p)", - "requirements": ["log_probs"], - }, - "logprobs_std_nll_ratio": { - "display_name": "BS Std(logp) / NLL", - "requirements": ["log_probs"], - }, -} -""" -All metrics that can be computed by the GFlowNetEvaluator. Structured as a dict with the -metric names as keys and the metric display names and requirements as values. - -Requirements are used to decide which kind of data / samples is required to compute the -metric. - -Display names are used to log the metrics and to display them in the console. -""" - -ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) -""" -Union of all requirements of all metrics in `METRICS`. Computed from -:py:const:`METRICS`. -""" - - -class GFlowNetEvaluator: - def __init__(self, **kwargs): - """ - Base evaluator class for GFlowNetAgent. - - In charge of evaluating the GFlowNetAgent, computing metrics plotting figures - and optionally logging results using the GFlowNetAgent's logger. - - Only the `from_dir` and `from_agent` class methods should be used to instantiate - this class. - - Raises - ------ - NotImplementedError - If the `sentinel` keyword argument is not `_sentinel`, which is used to - prevent instantiation of the base class without using the `from_dir` or - `from_agent` class methods. - - """ - if kwargs.get("sentinel") is not _sentinel: - raise NotImplementedError( - "Base evaluator class should not be instantiated. Use " - + "GFlowNetEvaluator.from_dir or GFlowNetEvaluator.from_agent methods." - ) - self.gfn_agent = kwargs.get("gfn_agent") - self.config = self.gfn_agent.eval_config - self.logger = self.gfn_agent.logger - self.reqs = set() - - self.metrics = self.reqs = _sentinel - - self.update_all_metrics_and_requirements() - - self.metrics = self.make_metrics(self.config.metrics) - self.reqs = self.make_requirements() - - def update_all_metrics_and_requirements(self): - """ - Method to be implemented by subclasses to update the global dict of metrics and - requirements. - """ - pass - - @classmethod - def from_dir( - cls: "GFlowNetEvaluator", - path: Union[str, os.PathLike], - no_wandb: bool = True, - print_config: bool = False, - device: str = "cuda", - load_final_ckpt: bool = True, - ): - """ - Instantiate a GFlowNetEvaluator from a run directory. - - Parameters - ---------- - cls : GFlowNetEvaluator - Class to instantiate. - path : Union[str, os.PathLike] - Path to the run directory from which to load the GFlowNetAgent. - no_wandb : bool, optional - Prevent wandb initialization, by default True - print_config : bool, optional - Whether or not to print the resulting (loaded) config, by default False - device : str, optional - Device to use for the instantiated GFlowNetAgent, by default "cuda" - load_final_ckpt : bool, optional - Use the latest possible checkpoint available in the path, by default True - - Returns - ------- - GFlowNetEvaluator - Instance of GFlowNetEvaluator with the GFlowNetAgent loaded from the run. - """ - gfn_agent, _ = load_gflow_net_from_run_path( - path, - no_wandb=no_wandb, - print_config=print_config, - device=device, - load_final_ckpt=load_final_ckpt, - ) - return GFlowNetEvaluator.from_agent(gfn_agent) - - @classmethod - def from_agent(cls, gfn_agent): - """ - Instantiate a GFlowNetEvaluator from a GFlowNetAgent. - - Parameters - ---------- - cls : GFlowNetEvaluator - Evaluator class to instantiate. - gfn_agent : GFlowNetAgent - Instance of GFlowNetAgent to use for the GFlowNetEvaluator. - - Returns - ------- - GFlowNetEvaluator - Instance of GFlowNetEvaluator with the provided GFlowNetAgent. - """ - from gflownet.gflownet import GFlowNetAgent - - assert isinstance(gfn_agent, GFlowNetAgent), ( - "gfn_agent should be an instance of GFlowNetAgent, but is an instance of " - + f"{type(gfn_agent)}." - ) - - return GFlowNetEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) - - def make_metrics(self, metrics=None): - """ - Parse metrics from a dict, list, a string or None. - - - If `None`, all metrics are selected. - - If a string, it can be a comma-separated list of metric names, with or without - spaces. - - If a list, it should be a list of metric names (keys of `METRICS`). - - If a dict, its keys should be metric names and its values will be ignored: - they will be assigned from `METRICS`. - - All metrics must be in `METRICS`. - - Parameters - ---------- - metrics : Union[str, List[str]], optional - Metrics to compute when running the `evaluator.eval()` function. Defaults to - None, i.e. all metrics in `METRICS` are computed. - - Returns - ------- - dict - Dictionary of metrics to compute, with the metric names as keys and the - metric display names and requirements as values. - - Raises - ------ - ValueError - If a metric name is not in `METRICS`. - """ - if metrics is None: - assert self.metrics is not _sentinel, ( - "Error setting self.metrics. This is likely due to the `metrics:`" - + " entry missing from your eval config. Set it to 'all' to compute all" - + " metrics or to a comma-separated list of metric names (eg 'l1, kl')." - ) - return self.metrics - - if not isinstance(metrics, (str, list, dict)): - raise ValueError( - "metrics should be None, a string, a list or a dict," - + f" but is {type(metrics)}." - ) - - if metrics == "all": - metrics = METRICS.keys() - - if isinstance(metrics, str): - if metrics == "": - raise ValueError( - "`metrics` should not be an empty string. " - + "Set to 'all' or a list of metric names or None (null in YAML)." - ) - if "," in metrics: - metrics = metrics.split(",") - else: - metrics = [metrics] - - if isinstance(metrics, dict): - metrics = metrics.keys() - - metrics = [m.strip() for m in metrics] - - for m in metrics: - if m not in METRICS: - raise ValueError(f"Unknown metric name: {m}") - - return {m: METRICS[m] for m in metrics} - - def make_requirements(self, reqs=None, metrics=None): - """ - Make requirements for the metrics to compute. - - 1. If `metrics` is provided, they must be as a dict of metrics. The requirements - are computed from the `requirements` attribute of the metrics. - - 2. Otherwise, the requirements are computed from the `reqs` argument: - - If `reqs` is `"all"`, all requirements of all metrics are computed. - - If `reqs` is `None`, the evaluator's `self.reqs` attribute is used. - - If `reqs` is a list, it is used as the requirements. - - Parameters - ---------- - reqs : Union[str, List[str]], optional - The metrics requirements. Either `"all"`, a list of requirements or `None` - to use the evaluator's `self.reqs` attribute. By default None - metrics : Union[str, List[str], dict], optional - The metrics to compute requirements for. If not a dict, will be passed to - `make_metrics`. By default None. - - Returns - ------- - set[str] - The set of requirements for the metrics. - """ - - if metrics is not None: - if not isinstance(metrics, dict): - metrics = self.make_metrics(metrics) - for m in metrics: - if m not in METRICS: - raise ValueError(f"Unknown metric name: {m}") - return set([r for m in metrics.values() for r in m["requirements"]]) - - if isinstance(reqs, str): - if reqs == "all": - reqs = ALL_REQS.copy() - else: - raise ValueError( - "reqs should be 'all', a list of requirements or None, but is " - + f"{reqs}." - ) - if reqs is None: - if self.reqs is _sentinel: - if not isinstance(self.metrics, dict): - raise ValueError( - "Cannot compute requirements from `None` without the `metrics`" - + " argument or the `self.metrics` attribute set to a dict" - + " of metrics." - ) - self.reqs = set( - [r for m in self.metrics.values() for r in m["requirements"]] - ) - reqs = self.reqs - if isinstance(reqs, list): - reqs = set(reqs) - - assert isinstance( - reqs, set - ), f"reqs should be None, 'all', a set or a list, but is {type(reqs)}" - - assert all([isinstance(r, str) for r in reqs]), ( - "All elements of reqs should be strings, but are " - + f"{[type(r) for r in reqs]}" - ) - - for r in reqs: - if r not in ALL_REQS: - raise ValueError(f"Unknown requirement: {r}") - return reqs - - def should_log_train(self, step): - """ - Check if training logs should be done at the current step. The decision is based - on the `self.config.train.period` attribute. - - Set `self.config.train.period` to `None` or a negative value to disable - training. - - Parameters - ---------- - step : int - Current iteration step. - - Returns - ------- - bool - True if train logging should be done at the current step, False otherwise. - """ - if self.config.train_log_period is None or self.config.train_log_period <= 0: - return False - else: - return step % self.config.train_log_period == 0 - - def should_eval(self, step): - """ - Check if testing should be done at the current step. The decision is based on - the `self.config.test.period` attribute. - - Set `self.config.test.first_it` to `True` if testing should be done at the first - iteration step. Otherwise, testing will be done aftter `self.config.test.period` - steps. - - Set `self.config.test.period` to `None` or a negative value to disable testing. - - Parameters - ---------- - step : int - Current iteration step. - - Returns - ------- - bool - True if testing should be done at the current step, False otherwise. - """ - if self.config.period is None or self.config.period <= 0: - return False - elif step == 1 and self.config.first_it: - return True - else: - return step % self.config.period == 0 - - def should_eval_top_k(self, step): - """ - Check if top k plots and metrics should be done at the current step. The - decision is based on the `self.config.test.top_k` and - `self.config.test.top_k_period` attributes. - - Set `self.config.test.top_k` to `None` or a negative value to disable top k - plots and metrics. - - Parameters - ---------- - step : int - Current iteration step. - - Returns - ------- - bool - True if top k plots and metrics should be done at the current step, False - """ - if self.config.top_k is None or self.config.top_k <= 0: - return False - - if self.config.top_k_period is None or self.config.top_k_period <= 0: - return False - - if step == 1 and self.config.first_it: - return True - - return step % self.config.top_k_period == 0 - - def should_checkpoint(self, step): - """ - Check if checkpoints should be done at the current step. The decision is based - on the `self.checkpoints.period` attribute. - - Set `self.checkpoints.period` to `None` or a negative value to disable - checkpoints. - - Parameters - ---------- - step : int - Current iteration step. - - Returns - ------- - bool - True if checkpoints should be done at the current step, False otherwise. - """ - if ( - self.config.checkpoints_period is None - or self.config.checkpoints_period <= 0 - ): - return False - else: - return not step % self.config.checkpoints_period +class GFlowNetEvaluator(GFlowNetAbstractEvaluator): def plot( self, x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs @@ -843,162 +436,6 @@ def eval(self, metrics=None, **plot_kwargs): "data": all_data, } - @torch.no_grad() - def eval_top_k(self, it, gfn_states=None, random_states=None): - """ - Sample from the current GFN and compute metrics and plots for the top k states - according to both the energy and the reward. - - Parameters - ---------- - it : int - current iteration - gfn_states : list, optional - Already sampled gfn states. Defaults to None. - random_states : list, optional - Already sampled random states. Defaults to None. - - Returns - ------- - dict - Computed dict of metrics, and figures, and optionally (only once) summary - metrics. Schema: ``{"metrics": {str: float}, "figs": {str: plt.Figure}, - "summary": {str: float}}``. - """ - # only do random top k plots & metrics once - do_random = it // self.logger.test.top_k_period == 1 - duration = None - summary = {} - prob = copy.deepcopy(self.random_action_prob) - gfn = self.gfn_agent - print() - if not gfn_states: - # sample states from the current gfn - batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) - gfn.random_action_prob = 0 - t = time.time() - print("Sampling from GFN...", end="\r") - for b in batch_with_rest(0, gfn.logger.test.n_top_k, gfn.batch_size_total): - sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) - batch.merge(sub_batch) - duration = time.time() - t - gfn_states = batch.get_terminating_states() - - # compute metrics and get plots - print("[eval_top_k] Making GFN plots...", end="\r") - metrics, figs, fig_names = gfn.env.top_k_metrics_and_plots( - gfn_states, gfn.logger.test.top_k, name="gflownet", step=it - ) - if duration: - metrics["gflownet top k sampling duration"] = duration - - if do_random: - # sample random states from uniform actions - if not random_states: - batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) - gfn.random_action_prob = 1.0 - print("[eval_top_k] Sampling at random...", end="\r") - for b in batch_with_rest( - 0, gfn.logger.test.n_top_k, gfn.batch_size_total - ): - sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) - batch.merge(sub_batch) - # compute metrics and get plots - random_states = batch.get_terminating_states() - print("[eval_top_k] Making Random plots...", end="\r") - ( - random_metrics, - random_figs, - random_fig_names, - ) = gfn.env.top_k_metrics_and_plots( - random_states, gfn.logger.test.top_k, name="random", step=None - ) - # add to current metrics and plots - summary.update(random_metrics) - figs += random_figs - fig_names += random_fig_names - # compute training data metrics and get plots - print("[eval_top_k] Making train plots...", end="\r") - ( - train_metrics, - train_figs, - train_fig_names, - ) = gfn.env.top_k_metrics_and_plots( - None, gfn.logger.test.top_k, name="train", step=None - ) - # add to current metrics and plots - summary.update(train_metrics) - figs += train_figs - fig_names += train_fig_names - - gfn.random_action_prob = prob - - print(" " * 100, end="\r") - print("eval_top_k metrics:") - max_k = max([len(k) for k in (list(metrics.keys()) + list(summary.keys()))]) + 1 - print( - " • " - + "\n • ".join( - f"{k:{max_k}}: {v:.4f}" - for k, v in (list(metrics.items()) + list(summary.items())) - ) - ) - print() - - figs = {f: n for f, n in zip(figs, fig_names)} - - return { - "metrics": metrics, - "figs": figs, - "summary": summary, - } - - def eval_and_log(self, it, metrics=None): - """ - Evaluate the GFlowNetAgent and log the results with its logger. - - Will call `self.eval()` and log the results using the GFlowNetAgent's logger - `log_metrics()` and `log_plots()` methods. - - Parameters - ---------- - it : int - Current iteration step. - metrics : Union[str, List[str]], optional - List of metrics to compute, by default the evaluator's `metrics` attribute. - """ - gfn = self.gfn_agent - results = self.eval(metrics=metrics) - for m, v in results["metrics"].items(): - setattr(gfn, m, v) - - mertics_to_log = { - METRICS[k]["display_name"]: v for k, v in results["metrics"].items() - } - - figs = self.plot(**results["data"]) - - self.logger.log_metrics(mertics_to_log, it, gfn.use_context) - self.logger.log_plots(figs, it, use_context=gfn.use_context) - - def eval_and_log_top_k(self, it): - """ - Evaluate the GFlowNetAgent's top k samples performance and log the results with - its logger. - - Parameters - ---------- - it : int - Current iteration step, by default None. - """ - - results = self.eval_top_k(it) - self.logger.log_plots(results["figs"], it, use_context=self.use_context) - self.logger.log_metrics( - results["metrics"], use_context=self.use_context, step=it - ) - self.logger.log_summary(results["summary"]) - if __name__ == "__main__": # Try using the GFlowNetEvaluator by running this script from the root: From d73de327187c66b6d7561db02a934f0848603683 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 15:15:22 -0500 Subject: [PATCH 068/106] update example docstring --- gflownet/evaluator/base.py | 183 ++++++++++++++++++++----------------- 1 file changed, 99 insertions(+), 84 deletions(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index d4e024546..f3a9087c5 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -4,9 +4,11 @@ In charge of evaluating the GFlowNetAgent, computing metrics plotting figures and optionally logging results using the GFlowNetAgent's logger. -Only the :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.from_dir` and -:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.from_agent` class methods should be -used to instantiate this class. +.. important:: + + Only the :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_dir` + and :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_agent` + class methods should be used to instantiate this class. Create a new evaluator by subclassing this class and extending the :py:meth:`eval` method to add more metrics and plots. @@ -14,26 +16,52 @@ Typical call stack: 1. :py:meth:`gflownet.gflownet.GFlowNetAgent.train` calls the evaluator's + 2. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.should_eval`. - If it returns ``True`` then :py:meth:`~gflownet.gflownet.GFlowNetAgent.train` calls + If it returns ``True`` then :py:meth:`~gflownet.gflownet.GFlowNetAgent.train` calls + 3. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` which itself calls + 4. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` as - ``results = self.eval(metrics=None)`` and then ``figs = self.plot(**results["data"])`` + ``results = self.eval(metrics=None)`` and then + ``figs = self.plot(**results["data"])`` + 5. finally, :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` logs the - results using the GFlowNetAgent's logger as - ``self.logger.log_metrics(results["metrics"])`` and ``self.logger.log_plots(figs)``. + results using the GFlowNetAgent's logger as + ``self.logger.log_metrics(results["metrics"])`` and ``self.logger.log_plots(figs)``. Example ------- .. code-block:: python - # gflownet/evaluator/my_evaluator.py + # How to create a new evaluator: + from gflownet.evaluator.base import GFlowNetEvaluator + + gfn_run_dir = "PUT_YOUR_RUN_DIR_HERE" # a run dir contains a .hydra folder + gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) + results = gfne.eval() + + for name, metric in results["metrics"].items(): + print(f"{name:20}: {metric:.4f}") + + data = results.get("data", {}) + plots = gfne.plot(**data) + + print( + "Available figures in plots:", + ", ".join([fname for fname, fig in plots.items() if fig is not None]) + or "None", + ) + + +.. code-block:: python + + # gflownet/evaluator/my_evaluator.py from gflownet.evaluator.base import GFlowNetEvaluator, METRICS, ALL_REQS class MyEvaluator(GFlowNetEvaluator): - def update_all_metrics_and_requirements(self): global METRICS, ALL_REQS @@ -123,92 +151,21 @@ def eval(self, metrics=None, **plot_kwargs): ``fig_dict = self.plot(**results["data"])`` where ``results`` is the output of ``eval``. """ -import copy -import os import pickle -import time from collections import defaultdict -from typing import Union import numpy as np import torch from scipy.special import logsumexp -from gflownet.evaluator.abstract import ( - ALL_REQS, - METRICS, - GFlowNetAbstractEvaluator, - _sentinel, -) -from gflownet.utils.batch import Batch -from gflownet.utils.common import ( - batch_with_rest, - load_gflow_net_from_run_path, - tfloat, - torch2np, -) +from gflownet.evaluator.abstract import ALL_REQS # noqa +from gflownet.evaluator.abstract import METRICS # noqa +from gflownet.evaluator.abstract import GFlowNetAbstractEvaluator +from gflownet.utils.common import tfloat, torch2np class GFlowNetEvaluator(GFlowNetAbstractEvaluator): - def plot( - self, x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs - ): - """ - Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which - will be logged. - - By default, this method will call the `plot_reward_samples` method of the - GFlowNetAgent's environment, and the `plot_kde` method of the GFlowNetAgent's - environment if it exists for both the `kde_pred` and `kde_true` arguments. - - Extend this method to add more plots: - - .. code-block:: python - - def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): - figs = super().plot(x_sampled, kde_pred, kde_true, plot_kwargs) - figs["My custom plot"] = my_custom_plot_function(x_sampled, kde_pred) - return figs - - Parameters - ---------- - x_sampled : list, optional - List of sampled states. - kde_pred : sklearn.neighbors.KernelDensity - KDE policy as per `Environment.fit_kde` - kde_true : object - True KDE. - plot_kwargs : dict - Additional keyword arguments to pass to the plotting methods. - kwargs : dict - Catch-all for additional arguments. - - Returns - ------- - dict[str, plt.Figure] - Dictionary of figures to be logged. The keys are the figure names and the - values are the figures. - """ - gfn = self.gfn_agent - - fig_kde_pred = fig_kde_true = fig_reward_samples = None - - if hasattr(gfn.env, "plot_reward_samples") and x_sampled is not None: - fig_reward_samples = gfn.env.plot_reward_samples(x_sampled, **plot_kwargs) - - if hasattr(gfn.env, "plot_kde"): - if kde_pred is not None: - fig_kde_pred = gfn.env.plot_kde(kde_pred, **plot_kwargs) - if kde_true is not None: - fig_kde_true = gfn.env.plot_kde(kde_true, **plot_kwargs) - - return { - "True reward and GFlowNet samples": fig_reward_samples, - "GFlowNet KDE Policy": fig_kde_pred, - "Reward KDE": fig_kde_true, - } - def compute_log_prob_metrics(self, x_tt, metrics=None): gfn = self.gfn_agent metrics = self.make_metrics(metrics) @@ -436,6 +393,64 @@ def eval(self, metrics=None, **plot_kwargs): "data": all_data, } + def plot( + self, x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs + ): + """ + Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which + will be logged. + + By default, this method will call the `plot_reward_samples` method of the + GFlowNetAgent's environment, and the `plot_kde` method of the GFlowNetAgent's + environment if it exists for both the `kde_pred` and `kde_true` arguments. + + Extend this method to add more plots: + + .. code-block:: python + + def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): + figs = super().plot(x_sampled, kde_pred, kde_true, plot_kwargs) + figs["My custom plot"] = my_custom_plot_function(x_sampled, kde_pred) + return figs + + Parameters + ---------- + x_sampled : list, optional + List of sampled states. + kde_pred : sklearn.neighbors.KernelDensity + KDE policy as per `Environment.fit_kde` + kde_true : object + True KDE. + plot_kwargs : dict + Additional keyword arguments to pass to the plotting methods. + kwargs : dict + Catch-all for additional arguments. + + Returns + ------- + dict[str, plt.Figure] + Dictionary of figures to be logged. The keys are the figure names and the + values are the figures. + """ + gfn = self.gfn_agent + + fig_kde_pred = fig_kde_true = fig_reward_samples = None + + if hasattr(gfn.env, "plot_reward_samples") and x_sampled is not None: + fig_reward_samples = gfn.env.plot_reward_samples(x_sampled, **plot_kwargs) + + if hasattr(gfn.env, "plot_kde"): + if kde_pred is not None: + fig_kde_pred = gfn.env.plot_kde(kde_pred, **plot_kwargs) + if kde_true is not None: + fig_kde_true = gfn.env.plot_kde(kde_true, **plot_kwargs) + + return { + "True reward and GFlowNet samples": fig_reward_samples, + "GFlowNet KDE Policy": fig_kde_pred, + "Reward KDE": fig_kde_true, + } + if __name__ == "__main__": # Try using the GFlowNetEvaluator by running this script from the root: From cdb66b38d91c312df8b22ae25933a16bcf4c85eb Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Fri, 1 Mar 2024 15:28:53 -0500 Subject: [PATCH 069/106] more docs --- gflownet/evaluator/__init__.py | 150 +++++++++++++++++++++++++++++++++ gflownet/evaluator/abstract.py | 6 +- gflownet/evaluator/base.py | 139 ------------------------------ 3 files changed, 154 insertions(+), 141 deletions(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index e69de29bb..0afeb7361 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -0,0 +1,150 @@ +""" +Create a new evaluator by subclassing this class and extending the :py:meth:`eval` +method to add more metrics and plots. + +.. important:: + + Only the :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_dir` + and :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_agent` + class methods should be used to instantiate this class. + +Typical call stack: + +1. :py:meth:`gflownet.gflownet.GFlowNetAgent.train` calls the evaluator's + +2. :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.should_eval`. + If it returns ``True`` then :py:meth:`~gflownet.gflownet.GFlowNetAgent.train` calls + +3. :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.eval_and_log` + which itself calls + +4. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` as + ``results = self.eval(metrics=None)`` and then + ``figs = self.plot(**results["data"])`` + +5. finally, :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` logs the + results using the GFlowNetAgent's logger as + ``self.logger.log_metrics(results["metrics"])`` and ``self.logger.log_plots(figs)``. + +Using an Evaluator +------------------ + +.. code-block:: python + + # How to create a new evaluator: + from gflownet.evaluator.base import GFlowNetEvaluator + + gfn_run_dir = "PUT_YOUR_RUN_DIR_HERE" # a run dir contains a .hydra folder + gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) + results = gfne.eval() + + for name, metric in results["metrics"].items(): + print(f"{name:20}: {metric:.4f}") + + data = results.get("data", {}) + + plots = gfne.plot(**data) + + print( + "Available figures in plots:", + ", ".join([fname for fname, fig in plots.items() if fig is not None]) + or "None", + ) + +Implementing your own evaluator +------------------------------- + +.. code-block:: python + + # gflownet/evaluator/my_evaluator.py + from gflownet.evaluator.base import GFlowNetEvaluator, METRICS, ALL_REQS + + class MyEvaluator(GFlowNetEvaluator): + def update_all_metrics_and_requirements(self): + global METRICS, ALL_REQS + + METRICS["my_custom_metric"] = { + "display_name": "My custom metric", + "requirements": ["density", "new_req"], + } + + ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) + + + def my_custom_metric(self, some, arguments): + intermediate = some + arguments + + return { + "metrics": { + "my_custom_metric": intermediate ** (-0.5) + }, + "data": { + "some_other": some ** 2, + "arguments": arguments, + "intermediate": intermediate, + } + } + ... + + def my_custom_plot( + self, some_other=None, arguments=None, intermediate=None, **kwargs + ): + # whatever gets to **kwargs will be ignored, this is used to handle + # methods with varying signatures. + figs = {} + if some_other is not None: + f = plt.figure() + # some plotting procedure for some_other + figs["My Title"] = f + + if arguments is not None: + f = plt.figure() + # some other plotting procedure with both + figs["My Other Title"] = f + elif arguments is not None: + f = plt.figure() + # some other plotting procedure with arguments + figs["My 3rd Title"] = f + + if intermediate is not None: + f = plt.figure() + # some other plotting procedure with intermediate + figs["My 4th Title"] = f + + return figs + + def plot(self, **kwargs): + figs = super().plot(**kwargs) + figs.update(self.my_custom_plot(**kwargs)) + + return figs + + def eval(self, metrics=None, **plot_kwargs): + gfn = self.gfn_agent + metrics = self.make_metrics(metrics) + reqs = self.make_requirements(metrics=metrics) + + results = super().eval(metrics=metrics, **plot_kwargs) + + if "new_req" in reqs: + my_results = self.my_custom_metric(some, arguments) + results["metrics"].update(my_results.get("metrics", {})) + results["data"].update(my_results.get("data", {})) + + return results + +In the previous example, the `update_all_metrics_and_requirements` method is used to +update the global `METRICS` and `ALL_REQS` variables. It will be called when the +`MyEvaluator` class is instantiated, in the init of `BaseEvaluator`. + +By defining a new requirement, you ensure that the new metrics and plots will only be +computed if user asks for a metric that requires such computations. + +By default, the train loop will call +:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` which itself calls +:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` so if you override ``eval()`` +as above, the new metrics and plots will be computed and logged. + +Similarly, `eval_and_log` will compute the ``dict`` of figures as +``fig_dict = self.plot(**results["data"])`` where ``results`` is the output of ``eval``. +""" diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 77a435305..b7dcb9dcb 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -1,8 +1,10 @@ """ Abstract evaluator class for GFlowNetAgent. -Should not be used directly, but subclassed to implement specific evaluators for -different tasks and environments. +.. warning:: + + Should not be used directly, but subclassed to implement specific evaluators for + different tasks and environments. See :py:class:`~gflownet.evaluator.base.GFlowNetEvaluator` for a default, concrete implementation of this abstract class. diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index f3a9087c5..7e1016ea6 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -10,145 +10,6 @@ and :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_agent` class methods should be used to instantiate this class. -Create a new evaluator by subclassing this class and extending the :py:meth:`eval` -method to add more metrics and plots. - -Typical call stack: - -1. :py:meth:`gflownet.gflownet.GFlowNetAgent.train` calls the evaluator's - -2. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.should_eval`. - If it returns ``True`` then :py:meth:`~gflownet.gflownet.GFlowNetAgent.train` calls - -3. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` which itself calls - -4. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` as - ``results = self.eval(metrics=None)`` and then - ``figs = self.plot(**results["data"])`` - -5. finally, :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` logs the - results using the GFlowNetAgent's logger as - ``self.logger.log_metrics(results["metrics"])`` and ``self.logger.log_plots(figs)``. - -Example -------- - -.. code-block:: python - - # How to create a new evaluator: - from gflownet.evaluator.base import GFlowNetEvaluator - - gfn_run_dir = "PUT_YOUR_RUN_DIR_HERE" # a run dir contains a .hydra folder - gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) - results = gfne.eval() - - for name, metric in results["metrics"].items(): - print(f"{name:20}: {metric:.4f}") - - data = results.get("data", {}) - - plots = gfne.plot(**data) - - print( - "Available figures in plots:", - ", ".join([fname for fname, fig in plots.items() if fig is not None]) - or "None", - ) - - -.. code-block:: python - - # gflownet/evaluator/my_evaluator.py - from gflownet.evaluator.base import GFlowNetEvaluator, METRICS, ALL_REQS - - class MyEvaluator(GFlowNetEvaluator): - def update_all_metrics_and_requirements(self): - global METRICS, ALL_REQS - - METRICS["my_custom_metric"] = { - "display_name": "My custom metric", - "requirements": ["density", "new_req"], - } - - ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) - - - def my_custom_metric(self, some, arguments): - intermediate = some + arguments - - return { - "metrics": { - "my_custom_metric": intermediate ** (-0.5) - }, - "data": { - "some_other": some ** 2, - "arguments": arguments, - "intermediate": intermediate, - } - } - ... - - def my_custom_plot( - self, some_other=None, arguments=None, intermediate=None, **kwargs - ): - # whatever gets to **kwargs will be ignored, this is used to handle - # methods with varying signatures. - figs = {} - if some_other is not None: - f = plt.figure() - # some plotting procedure for some_other - figs["My Title"] = f - - if arguments is not None: - f = plt.figure() - # some other plotting procedure with both - figs["My Other Title"] = f - elif arguments is not None: - f = plt.figure() - # some other plotting procedure with arguments - figs["My 3rd Title"] = f - - if intermediate is not None: - f = plt.figure() - # some other plotting procedure with intermediate - figs["My 4th Title"] = f - - return figs - - def plot(self, **kwargs): - figs = super().plot(**kwargs) - figs.update(self.my_custom_plot(**kwargs)) - - return figs - - def eval(self, metrics=None, **plot_kwargs): - gfn = self.gfn_agent - metrics = self.make_metrics(metrics) - reqs = self.make_requirements(metrics=metrics) - - results = super().eval(metrics=metrics, **plot_kwargs) - - if "new_req" in reqs: - my_results = self.my_custom_metric(some, arguments) - results["metrics"].update(my_results.get("metrics", {})) - results["data"].update(my_results.get("data", {})) - - return results - -In the previous example, the `update_all_metrics_and_requirements` method is used to -update the global `METRICS` and `ALL_REQS` variables. It will be called when the -`MyEvaluator` class is instantiated, in the init of `BaseEvaluator`. - -By defining a new requirement, you ensure that the new metrics and plots will only be -computed if user asks for a metric that requires such computations. - -By default, the train loop will call -:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` which itself calls -:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` so if you override ``eval()`` -as above, the new metrics and plots will be computed and logged. - -Similarly, `eval_and_log` will compute the ``dict`` of figures as -``fig_dict = self.plot(**results["data"])`` where ``results`` is the output of ``eval``. """ import pickle From 473edf9a75bfb0dc7e3cb59dd5ebf1f1e18f2cde Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 14:09:29 -0500 Subject: [PATCH 070/106] allow `init` instantiation + more tutorial --- config/eval/base.yaml | 2 + gflownet/evaluator/__init__.py | 106 ++++++++++++++- gflownet/evaluator/abstract.py | 230 +++++++++++++-------------------- gflownet/evaluator/base.py | 193 ++++++++++++++++++++++----- gflownet/gflownet.py | 11 +- gflownet/utils/common.py | 9 +- main.py | 5 +- 7 files changed, 370 insertions(+), 186 deletions(-) diff --git a/config/eval/base.yaml b/config/eval/base.yaml index 720f65dd4..a679fdda7 100644 --- a/config/eval/base.yaml +++ b/config/eval/base.yaml @@ -1,3 +1,5 @@ +_target_: gflownet.evaluator.base.GFlowNetEvaluator + # config formerly from logger.test first_it: True period: 100 diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index 0afeb7361..d0a8cdf8d 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -61,6 +61,11 @@ class methods should be used to instantiate this class. class MyEvaluator(GFlowNetEvaluator): def update_all_metrics_and_requirements(self): + ''' + This method is called when the class is instantiated and is used to update + the global METRICS and ALL_REQS variables. It is used to define new metrics: + their display names (when logged) and requirements. + ''' global METRICS, ALL_REQS METRICS["my_custom_metric"] = { @@ -72,6 +77,28 @@ def update_all_metrics_and_requirements(self): def my_custom_metric(self, some, arguments): + ''' + Your metric-computing method. It should return a dict with two keys: + "metrics" and "data". + + The "metrics" key should contain the new metric(s) and the "data" key + should contain the intermediate results that can be used to plot the + new metric(s). + + Its arguments will come from the `eval()` method below. + + Parameters + ---------- + some : type + description + arguments : type + description + + Returns + ------- + dict + A dict with two keys: "metrics" and "data". + ''' intermediate = some + arguments return { @@ -89,6 +116,29 @@ def my_custom_metric(self, some, arguments): def my_custom_plot( self, some_other=None, arguments=None, intermediate=None, **kwargs ): + ''' + Your plotting method. + + It should return a dict with figure titles as keys and the figures as + values. + + Its arguments will come from the `plot()` method below, and basically come + from the "data" key of the output of other metrics-computing functions. + + Parameters + ---------- + some_other : type, optional + description, by default None + arguments : type, optional + description, by default None + intermediate : type, optional + description, by default None + + Returns + ------- + dict + A dict with figure titles as keys and the figures as values. + ''' # whatever gets to **kwargs will be ignored, this is used to handle # methods with varying signatures. figs = {} @@ -114,28 +164,76 @@ def my_custom_plot( return figs def plot(self, **kwargs): + ''' + Your custom plot method. + + It should return a dict with figure titles as keys and the figures as + values. + + It will be called by the `eval_and_log` method to log the figures, + and given the "data" key of the output of other metrics-computing functions. + + Returns + ------- + dict + A dict with figure titles as keys and the figures as values. + ''' figs = super().plot(**kwargs) figs.update(self.my_custom_plot(**kwargs)) return figs def eval(self, metrics=None, **plot_kwargs): - gfn = self.gfn_agent + ''' + Your custom eval method. + + It should return a dict with two keys: "metrics" and "data". + + It will be called by the `eval_and_log` method to log the metrics, + + Parameters + ---------- + metrics : Union[list, dict], optional + The metrics you want to compute in this evaluation procedure, + by default None, meaning the ones defined in the config file. + + Returns + ------- + dict + A dict with two keys: "metrics" and "data". + ''' metrics = self.make_metrics(metrics) reqs = self.make_requirements(metrics=metrics) results = super().eval(metrics=metrics, **plot_kwargs) if "new_req" in reqs: + some = self.gfn.sample_something() + arguments = utils.some_other_function() my_results = self.my_custom_metric(some, arguments) results["metrics"].update(my_results.get("metrics", {})) results["data"].update(my_results.get("data", {})) return results -In the previous example, the `update_all_metrics_and_requirements` method is used to -update the global `METRICS` and `ALL_REQS` variables. It will be called when the -`MyEvaluator` class is instantiated, in the init of `BaseEvaluator`. +Then define your own ``evaluator`` in the config file: + +.. code-block:: yaml + + # gflownet/config/evaluator/my_evaluator.yaml + defaults: + - base + + _target_: gflownet.evaluator.my_evaluator.MyEvaluator + + # any other params hereafter will extend or override the base class params: + + period: 1000 + + +In the previous example, the ``update_all_metrics_and_requirements`` method is used to +update the global ``METRICS`` and ``ALL_REQS`` variables. It will be called when the +``MyEvaluator`` class is instantiated, in the init of ``BaseEvaluator``. By defining a new requirement, you ensure that the new metrics and plots will only be computed if user asks for a metric that requires such computations. diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index b7dcb9dcb..219a761d8 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -17,10 +17,9 @@ .. code-block:: python def eval_and_log(self, it, metrics=None): - gfn = self.gfn_agent results = self.eval(metrics=metrics) for m, v in results["metrics"].items(): - setattr(gfn, m, v) + setattr(self.gfn, m, v) mertics_to_log = { METRICS[k]["display_name"]: v for k, v in results["metrics"].items() @@ -28,20 +27,20 @@ def eval_and_log(self, it, metrics=None): figs = self.plot(**results["data"]) - self.logger.log_metrics(mertics_to_log, it, gfn.use_context) - self.logger.log_plots(figs, it, use_context=gfn.use_context) + self.logger.log_metrics(mertics_to_log, it, self.gfn.use_context) + self.logger.log_plots(figs, it, use_context=self.gfn.use_context) + +See :py:mod:`gflownet.evaluator` for a full-fledged example and +:py:mod:`gflownet.evaluator.base` for a concrete implementation of this abstract class. """ -import copy import os -import time from abc import ABCMeta, abstractmethod from typing import Union -import torch +from omegaconf import OmegaConf -from gflownet.utils.batch import Batch -from gflownet.utils.common import batch_with_rest, load_gflow_net_from_run_path +from gflownet.utils.common import load_gflow_net_from_run_path _sentinel = object() @@ -91,6 +90,10 @@ def eval_and_log(self, it, metrics=None): metric. Display names are used to log the metrics and to display them in the console. + +Implementations of :py:class:`GFlowNetAbstractEvaluator` should update +this dict and the :py:const:`ALL_REQS` set to include new metrics by implementing the +:py:method:`update_all_metrics_and_requirements` method. """ ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) @@ -101,15 +104,38 @@ def eval_and_log(self, it, metrics=None): class GFlowNetAbstractEvaluator(metaclass=ABCMeta): - def __init__(self, **kwargs): + def __init__(self, gfn_agent=None, **config): """ Base evaluator class for GFlowNetAgent. In charge of evaluating the GFlowNetAgent, computing metrics plotting figures and optionally logging results using the GFlowNetAgent's logger. - Only the `from_dir` and `from_agent` class methods should be used to instantiate - this class. + You can use the :py:method:`from_dir` or :py:method:`from_agent` class methods + to easily instantiate this class from a run directory or an existing + in-memory ``GFlowNetAgent``. + + Use :py:method:`set_agent` to set the evaluator's ``GFlowNetAgent`` after + initialization if it was not provided at instantiation as ``gfn_agent=``. + + This ``init`` function will call, in order: + + 1. :py:method:`update_all_metrics_and_requirements` + + 2. ``self.metrics = self.make_metrics(self.config.metrics)`` using + :py:method:`make_metrics` + + 3. ``self.reqs = self.make_requirements()`` using :py:method:`make_requirements` + + Arguments + --------- + gfn_agent : GFlowNetAgent, optional + The GFlowNetAgent to evaluate. By default None. Should be set using the + :py:method:`from_dir` or :py:method:`from_agent` class methods. + + config : dict + The configuration of the evaluator. Will be converted to an OmegaConf + instance and stored in the ``self.config`` attribute. Raises ------ @@ -118,24 +144,59 @@ def __init__(self, **kwargs): prevent instantiation of the base class without using the `from_dir` or `from_agent` class methods. + Attributes + ---------- + config : OmegaConf + The configuration of the evaluator. + metrics : dict + Dictionary of metrics to compute, with the metric names as keys and the + metric display names and requirements as values. + reqs : set[str] + The set of requirements for the metrics. Used to decide which kind of data / + samples is required to compute the metric. + logger : Logger + The logger to use to log the results of the evaluation. Will be set to the + GFlowNetAgent's logger. + gfn: :py:class:`GFlowNetAgent` + The GFlowNetAgent to evaluate. """ - if kwargs.get("sentinel") is not _sentinel: - raise NotImplementedError( - "Base evaluator class should not be instantiated. Use " - + "GFlowNetEvaluator.from_dir or GFlowNetEvaluator.from_agent methods." - ) - self.gfn_agent = kwargs.get("gfn_agent") - self.config = self.gfn_agent.eval_config - self.logger = self.gfn_agent.logger - self.reqs = set() + + self._gfn_agent = gfn_agent + self.config = OmegaConf.create(config) + + if self._gfn_agent is not None: + self.logger = self._gfn_agent.logger self.metrics = self.reqs = _sentinel self.update_all_metrics_and_requirements() - self.metrics = self.make_metrics(self.config.metrics) self.reqs = self.make_requirements() + @property + def gfn(self): + if type(self._gfn_agent).__name__ != "GFlowNetAgent": + raise ValueError( + "The GFlowNetAgent has not been set. Use the `from_dir` or `from_agent`" + + " class methods to instantiate this class or set the `gfn` attribute." + ) + return self._gfn_agent + + def set_agent(self, gfn_agent): + assert type(gfn_agent).__name__ == "GFlowNetAgent", ( + "gfn_agent should be an instance of GFlowNetAgent, but is an instance of " + + f"{type(gfn_agent)}." + ) + self._gfn_agent = gfn_agent + self.logger = gfn_agent.logger + + @gfn.setter + def gfn(self, _): + raise AttributeError( + "The `gfn` attribute is read-only. Use the `set_agent` method to set the" + + " GFlowNetAgent." + ) + def update_all_metrics_and_requirements(self): """ Method to be implemented by subclasses to update the global dict of metrics and @@ -182,7 +243,7 @@ def from_dir( device=device, load_final_ckpt=load_final_ckpt, ) - return GFlowNetAbstractEvaluator.from_agent(gfn_agent) + return cls.from_agent(gfn_agent) @classmethod def from_agent(cls, gfn_agent): @@ -208,7 +269,7 @@ def from_agent(cls, gfn_agent): + f"{type(gfn_agent)}." ) - return GFlowNetAbstractEvaluator(gfn_agent=gfn_agent, sentinel=_sentinel) + return cls(gfn_agent=gfn_agent, **gfn_agent.evaluator.config) def make_metrics(self, metrics=None): """ @@ -467,6 +528,10 @@ def plot(self, **kwargs): def eval(self, metrics=None, **plot_kwargs): pass + @abstractmethod + def eval_top_k(self, it): + pass + def eval_and_log(self, it, metrics=None): """ Evaluate the GFlowNetAgent and log the results with its logger. @@ -481,10 +546,9 @@ def eval_and_log(self, it, metrics=None): metrics : Union[str, List[str]], optional List of metrics to compute, by default the evaluator's `metrics` attribute. """ - gfn = self.gfn_agent results = self.eval(metrics=metrics) for m, v in results["metrics"].items(): - setattr(gfn, m, v) + setattr(self.gfn, m, v) mertics_to_log = { METRICS[k]["display_name"]: v for k, v in results["metrics"].items() @@ -492,118 +556,8 @@ def eval_and_log(self, it, metrics=None): figs = self.plot(**results["data"]) - self.logger.log_metrics(mertics_to_log, it, gfn.use_context) - self.logger.log_plots(figs, it, use_context=gfn.use_context) - - @torch.no_grad() - def eval_top_k(self, it, gfn_states=None, random_states=None): - """ - Sample from the current GFN and compute metrics and plots for the top k states - according to both the energy and the reward. - - Parameters - ---------- - it : int - current iteration - gfn_states : list, optional - Already sampled gfn states. Defaults to None. - random_states : list, optional - Already sampled random states. Defaults to None. - - Returns - ------- - dict - Computed dict of metrics, and figures, and optionally (only once) summary - metrics. Schema: ``{"metrics": {str: float}, "figs": {str: plt.Figure}, - "summary": {str: float}}``. - """ - # only do random top k plots & metrics once - do_random = it // self.logger.test.top_k_period == 1 - duration = None - summary = {} - prob = copy.deepcopy(self.random_action_prob) - gfn = self.gfn_agent - print() - if not gfn_states: - # sample states from the current gfn - batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) - gfn.random_action_prob = 0 - t = time.time() - print("Sampling from GFN...", end="\r") - for b in batch_with_rest(0, gfn.logger.test.n_top_k, gfn.batch_size_total): - sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) - batch.merge(sub_batch) - duration = time.time() - t - gfn_states = batch.get_terminating_states() - - # compute metrics and get plots - print("[eval_top_k] Making GFN plots...", end="\r") - metrics, figs, fig_names = gfn.env.top_k_metrics_and_plots( - gfn_states, gfn.logger.test.top_k, name="gflownet", step=it - ) - if duration: - metrics["gflownet top k sampling duration"] = duration - - if do_random: - # sample random states from uniform actions - if not random_states: - batch = Batch(env=gfn.env, device=gfn.device, float_type=gfn.float) - gfn.random_action_prob = 1.0 - print("[eval_top_k] Sampling at random...", end="\r") - for b in batch_with_rest( - 0, gfn.logger.test.n_top_k, gfn.batch_size_total - ): - sub_batch, _ = gfn.sample_batch(n_forward=len(b), train=False) - batch.merge(sub_batch) - # compute metrics and get plots - random_states = batch.get_terminating_states() - print("[eval_top_k] Making Random plots...", end="\r") - ( - random_metrics, - random_figs, - random_fig_names, - ) = gfn.env.top_k_metrics_and_plots( - random_states, gfn.logger.test.top_k, name="random", step=None - ) - # add to current metrics and plots - summary.update(random_metrics) - figs += random_figs - fig_names += random_fig_names - # compute training data metrics and get plots - print("[eval_top_k] Making train plots...", end="\r") - ( - train_metrics, - train_figs, - train_fig_names, - ) = gfn.env.top_k_metrics_and_plots( - None, gfn.logger.test.top_k, name="train", step=None - ) - # add to current metrics and plots - summary.update(train_metrics) - figs += train_figs - fig_names += train_fig_names - - gfn.random_action_prob = prob - - print(" " * 100, end="\r") - print("eval_top_k metrics:") - max_k = max([len(k) for k in (list(metrics.keys()) + list(summary.keys()))]) + 1 - print( - " • " - + "\n • ".join( - f"{k:{max_k}}: {v:.4f}" - for k, v in (list(metrics.items()) + list(summary.items())) - ) - ) - print() - - figs = {f: n for f, n in zip(figs, fig_names)} - - return { - "metrics": metrics, - "figs": figs, - "summary": summary, - } + self.logger.log_metrics(mertics_to_log, it, self.gfn.use_context) + self.logger.log_plots(figs, it, use_context=self.gfn.use_context) def eval_and_log_top_k(self, it): """ diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 7e1016ea6..4bec305b3 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -1,9 +1,11 @@ """ Base evaluator class for GFlowNetAgent. -In charge of evaluating the GFlowNetAgent, computing metrics plotting figures and +In charge of evaluating a generic GFlowNetAgent, computing metrics plotting figures and optionally logging results using the GFlowNetAgent's logger. +Take it as example to implement your own evaluator class for your custom use-case. + .. important:: Only the :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_dir` @@ -12,7 +14,9 @@ class methods should be used to instantiate this class. """ +import copy import pickle +import time from collections import defaultdict import numpy as np @@ -22,17 +26,132 @@ class methods should be used to instantiate this class. from gflownet.evaluator.abstract import ALL_REQS # noqa from gflownet.evaluator.abstract import METRICS # noqa from gflownet.evaluator.abstract import GFlowNetAbstractEvaluator -from gflownet.utils.common import tfloat, torch2np +from gflownet.utils.batch import Batch +from gflownet.utils.common import batch_with_rest, tfloat, torch2np class GFlowNetEvaluator(GFlowNetAbstractEvaluator): + @torch.no_grad() + def eval_top_k(self, it, gfn_states=None, random_states=None): + """ + Sample from the current GFN and compute metrics and plots for the top k states + according to both the energy and the reward. + + Parameters + ---------- + it : int + current iteration + gfn_states : list, optional + Already sampled gfn states. Defaults to None. + random_states : list, optional + Already sampled random states. Defaults to None. + + Returns + ------- + dict + Computed dict of metrics, and figures, and optionally (only once) summary + metrics. Schema: ``{"metrics": {str: float}, "figs": {str: plt.Figure}, + "summary": {str: float}}``. + """ + # only do random top k plots & metrics once + do_random = it // self.logger.test.top_k_period == 1 + duration = None + summary = {} + prob = copy.deepcopy(self.random_action_prob) + print() + if not gfn_states: + # sample states from the current gfn + batch = Batch( + env=self.gfn.env, device=self.gfn.device, float_type=self.gfn.float + ) + self.gfn.random_action_prob = 0 + t = time.time() + print("Sampling from GFN...", end="\r") + for b in batch_with_rest( + 0, self.gfn.logger.test.n_top_k, self.gfn.batch_size_total + ): + sub_batch, _ = self.gfn.sample_batch(n_forward=len(b), train=False) + batch.merge(sub_batch) + duration = time.time() - t + gfn_states = batch.get_terminating_states() + + # compute metrics and get plots + print("[eval_top_k] Making GFN plots...", end="\r") + metrics, figs, fig_names = self.gfn.env.top_k_metrics_and_plots( + gfn_states, self.gfn.logger.test.top_k, name="gflownet", step=it + ) + if duration: + metrics["gflownet top k sampling duration"] = duration + + if do_random: + # sample random states from uniform actions + if not random_states: + batch = Batch( + env=self.gfn.env, device=self.gfn.device, float_type=self.gfn.float + ) + self.gfn.random_action_prob = 1.0 + print("[eval_top_k] Sampling at random...", end="\r") + for b in batch_with_rest( + 0, self.gfn.logger.test.n_top_k, self.gfn.batch_size_total + ): + sub_batch, _ = self.gfn.sample_batch(n_forward=len(b), train=False) + batch.merge(sub_batch) + # compute metrics and get plots + random_states = batch.get_terminating_states() + print("[eval_top_k] Making Random plots...", end="\r") + ( + random_metrics, + random_figs, + random_fig_names, + ) = self.gfn.env.top_k_metrics_and_plots( + random_states, self.gfn.logger.test.top_k, name="random", step=None + ) + # add to current metrics and plots + summary.update(random_metrics) + figs += random_figs + fig_names += random_fig_names + # compute training data metrics and get plots + print("[eval_top_k] Making train plots...", end="\r") + ( + train_metrics, + train_figs, + train_fig_names, + ) = self.gfn.env.top_k_metrics_and_plots( + None, self.gfn.logger.test.top_k, name="train", step=None + ) + # add to current metrics and plots + summary.update(train_metrics) + figs += train_figs + fig_names += train_fig_names + + self.gfn.random_action_prob = prob + + print(" " * 100, end="\r") + print("eval_top_k metrics:") + max_k = max([len(k) for k in (list(metrics.keys()) + list(summary.keys()))]) + 1 + print( + " • " + + "\n • ".join( + f"{k:{max_k}}: {v:.4f}" + for k, v in (list(metrics.items()) + list(summary.items())) + ) + ) + print() + + figs = {f: n for f, n in zip(figs, fig_names)} + + return { + "metrics": metrics, + "figs": figs, + "summary": summary, + } + def compute_log_prob_metrics(self, x_tt, metrics=None): - gfn = self.gfn_agent metrics = self.make_metrics(metrics) reqs = self.make_requirements(metrics=metrics) - logprobs_x_tt, logprobs_std, probs_std = gfn.estimate_logprobs_data( + logprobs_x_tt, logprobs_std, probs_std = self.gfn.estimate_logprobs_data( x_tt, n_trajectories=self.config.n_trajs_logprobs, max_data_size=self.config.max_data_logprobs, @@ -49,19 +168,23 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): lp_metrics["mean_probs_std"] = probs_std.mean().item() if "reward_batch" in reqs: - rewards_x_tt = gfn.env.reward_batch(x_tt) + rewards_x_tt = self.gfn.env.reward_batch(x_tt) if "corr_prob_traj_rewards" in metrics: - rewards_x_tt = gfn.env.reward_batch(x_tt) + rewards_x_tt = self.gfn.env.reward_batch(x_tt) lp_metrics["corr_prob_traj_rewards"] = np.corrcoef( np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt )[0, 1] if "var_logrewards_logp" in metrics: - rewards_x_tt = gfn.env.reward_batch(x_tt) + rewards_x_tt = self.gfn.env.reward_batch(x_tt) lp_metrics["var_logrewards_logp"] = torch.var( torch.log( - tfloat(rewards_x_tt, float_type=gfn.float, device=gfn.device) + tfloat( + rewards_x_tt, + float_type=self.gfn.float, + device=self.gfn.device, + ) ) - logprobs_x_tt ).item() @@ -78,7 +201,6 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): } def compute_density_metrics(self, x_tt, dict_tt, metrics=None): - gfn = self.gfn_agent metrics = self.make_metrics(metrics) density_metrics = {} @@ -86,18 +208,18 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): x_sampled = density_true = density_pred = None - if gfn.buffer.test_type is not None and gfn.buffer.test_type == "all": - batch, _ = gfn.sample_batch(n_forward=self.config.n, train=False) + if self.gfn.buffer.test_type is not None and self.gfn.buffer.test_type == "all": + batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False) assert batch.is_valid() x_sampled = batch.get_terminating_states() if "density_true" in dict_tt: density_true = dict_tt["density_true"] else: - rewards = gfn.env.reward_batch(x_tt) + rewards = self.gfn.env.reward_batch(x_tt) z_true = rewards.sum() density_true = rewards / z_true - with open(gfn.buffer.test_pkl, "wb") as f: + with open(self.gfn.buffer.test_pkl, "wb") as f: dict_tt["density_true"] = density_true pickle.dump(dict_tt, f) hist = defaultdict(int) @@ -108,14 +230,14 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): log_density_true = np.log(density_true + 1e-8) log_density_pred = np.log(density_pred + 1e-8) - elif gfn.continuous and hasattr(gfn.env, "fit_kde"): - batch, _ = gfn.sample_batch(n_forward=self.config.n, train=False) + elif self.gfn.continuous and hasattr(self.gfn.env, "fit_kde"): + batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False) assert batch.is_valid() x_sampled = batch.get_terminating_states() # TODO make it work with conditional env - x_sampled = torch2np(gfn.env.states2proxy(x_sampled)) - x_tt = torch2np(gfn.env.states2proxy(x_tt)) - kde_pred = gfn.env.fit_kde( + x_sampled = torch2np(self.gfn.env.states2proxy(x_sampled)) + x_tt = torch2np(self.gfn.env.states2proxy(x_tt)) + kde_pred = self.gfn.env.fit_kde( x_sampled, kernel=self.config.kde.kernel, bandwidth=self.config.kde.bandwidth, @@ -125,10 +247,10 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): kde_true = dict_tt["kde_true"] else: # Sample from reward via rejection sampling - x_from_reward = gfn.env.sample_from_reward(n_samples=self.config.n) - x_from_reward = torch2np(gfn.env.states2proxy(x_from_reward)) + x_from_reward = self.gfn.env.sample_from_reward(n_samples=self.config.n) + x_from_reward = torch2np(self.gfn.env.states2proxy(x_from_reward)) # Fit KDE with samples from reward - kde_true = gfn.env.fit_kde( + kde_true = self.gfn.env.fit_kde( x_from_reward, kernel=self.config.kde.kernel, bandwidth=self.config.kde.bandwidth, @@ -138,7 +260,7 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): scores_true = kde_true.score_samples(x_tt) log_density_true = scores_true - logsumexp(scores_true, axis=0) # Add log_density_true and kde_true to pickled test dict - with open(gfn.buffer.test_pkl, "wb") as f: + with open(self.gfn.buffer.test_pkl, "wb") as f: dict_tt["log_density_true"] = log_density_true dict_tt["kde_true"] = kde_true pickle.dump(dict_tt, f) @@ -153,9 +275,9 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): density_data["kde_true"] = kde_true else: - density_metrics["l1"] = gfn.l1 - density_metrics["kl"] = gfn.kl - density_metrics["jsd"] = gfn.jsd + density_metrics["l1"] = self.gfn.l1 + density_metrics["kl"] = self.gfn.kl + density_metrics["jsd"] = self.gfn.jsd density_data["x_sampled"] = x_sampled return { "metrics": density_metrics, @@ -215,14 +337,14 @@ def eval(self, metrics=None, **plot_kwargs): Computed dict of metrics and figures as `{"metrics": {str: float}, "figs": {str: plt.Figure}}`. """ - gfn = self.gfn_agent metrics = self.make_metrics(metrics) reqs = self.make_requirements(metrics=metrics) - if gfn.buffer.test_pkl is None: + if self.gfn.buffer.test_pkl is None: return { "metrics": { - k: getattr(gfn, k) if hasattr(gfn, k) else None for k in metrics + k: getattr(self.gfn, k) if hasattr(self.gfn, k) else None + for k in metrics }, "data": {}, } @@ -230,7 +352,7 @@ def eval(self, metrics=None, **plot_kwargs): all_data = {} all_metrics = {} - with open(gfn.buffer.test_pkl, "rb") as f: + with open(self.gfn.buffer.test_pkl, "rb") as f: dict_tt = pickle.load(f) x_tt = dict_tt["x"] @@ -293,18 +415,19 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): Dictionary of figures to be logged. The keys are the figure names and the values are the figures. """ - gfn = self.gfn_agent fig_kde_pred = fig_kde_true = fig_reward_samples = None - if hasattr(gfn.env, "plot_reward_samples") and x_sampled is not None: - fig_reward_samples = gfn.env.plot_reward_samples(x_sampled, **plot_kwargs) + if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None: + fig_reward_samples = self.gfn.env.plot_reward_samples( + x_sampled, **plot_kwargs + ) - if hasattr(gfn.env, "plot_kde"): + if hasattr(self.gfn.env, "plot_kde"): if kde_pred is not None: - fig_kde_pred = gfn.env.plot_kde(kde_pred, **plot_kwargs) + fig_kde_pred = self.gfn.env.plot_kde(kde_pred, **plot_kwargs) if kde_true is not None: - fig_kde_true = gfn.env.plot_kde(kde_true, **plot_kwargs) + fig_kde_true = self.gfn.env.plot_kde(kde_true, **plot_kwargs) return { "True reward and GFlowNet samples": fig_reward_samples, diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 8c7f3dbad..a5d23fe91 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -47,7 +47,7 @@ def __init__( pct_offline, logger, num_empirical_loss, - eval_config, + evaluator, state_flow=None, use_context=False, replay_sampling="permutation", @@ -94,9 +94,8 @@ def __init__( (`gflownet/utils/logger.py:Logger`). num_empirical_loss : int Number of empirical loss samples to be used for training. - eval_config : dict, optional - Evaluator config dictionary. See `eval/base.yaml` for details. By default - None. + evaluator : gflownet.evaluator.base.GFlowNetEvaluator + :py:mod:`~gflownet.evaluator` ``Evaluator`` instance. state_flow : dict, optional State flow config dictionary. See `gflownet.yaml:state_flow` for details. By default None. @@ -229,8 +228,8 @@ def __init__( self.opt, self.lr_scheduler, self.target = None, None, None # Evaluator - self.eval_config = eval_config - self.evaluator = GFlowNetEvaluator.from_agent(self) + self.evaluator = evaluator + self.evaluator.set_agent(self) self.n_train_steps = optimizer.n_train_steps self.batch_size = optimizer.batch_size diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index c3e34d6aa..fffe8cc9b 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -1,5 +1,6 @@ import os import random +from copy import deepcopy from os.path import expandvars from pathlib import Path from typing import List, Union @@ -8,7 +9,7 @@ import torch from hydra import compose, initialize_config_dir from hydra.utils import get_original_cwd, instantiate -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from torchtyping import TensorType from gflownet.utils.policy import parse_policy_config @@ -257,6 +258,9 @@ def gflownet_from_config(config): float_precision=config.float_precision, ) + # The evaluator is used to compute metrics and plots + evaluator = instantiate(config.eval) + # The policy is used to model the probability of a forward/backward action forward_config = parse_policy_config(config, kind="forward") backward_config = parse_policy_config(config, kind="backward") @@ -298,7 +302,8 @@ def gflownet_from_config(config): state_flow=state_flow, buffer=config.env.buffer, logger=logger, - eval_config=config.eval, + evaluator=evaluator, + full_config=config, ) return gflownet diff --git a/main.py b/main.py index fdfff3d43..036d9fc45 100644 --- a/main.py +++ b/main.py @@ -45,6 +45,8 @@ def main(config): device=config.device, float_precision=config.float_precision, ) + # The evaluator is used to compute metrics and plots + evaluator = hydra.utils.instantiate(config.eval) # The policy is used to model the probability of a forward/backward action forward_config = parse_policy_config(config, kind="forward") backward_config = parse_policy_config(config, kind="backward") @@ -74,6 +76,7 @@ def main(config): else: state_flow = None # GFlowNet Agent + gflownet = hydra.utils.instantiate( config.gflownet, device=config.device, @@ -84,7 +87,7 @@ def main(config): state_flow=state_flow, buffer=config.env.buffer, logger=logger, - eval_config=config.eval, + evaluator=evaluator, ) # Train GFlowNet From 670bdb8c9038e3c38b411b71a3aac07e250f9221 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 14:23:32 -0500 Subject: [PATCH 071/106] `define_new_metrics` --- gflownet/evaluator/__init__.py | 23 ++++++++++------------ gflownet/evaluator/abstract.py | 35 ++++++++++++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index d0a8cdf8d..e316c968f 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -60,21 +60,18 @@ class methods should be used to instantiate this class. from gflownet.evaluator.base import GFlowNetEvaluator, METRICS, ALL_REQS class MyEvaluator(GFlowNetEvaluator): - def update_all_metrics_and_requirements(self): + def define_new_metrics(self): ''' This method is called when the class is instantiated and is used to update - the global METRICS and ALL_REQS variables. It is used to define new metrics: - their display names (when logged) and requirements. + the global METRICS and ALL_REQS variables. ''' - global METRICS, ALL_REQS - - METRICS["my_custom_metric"] = { - "display_name": "My custom metric", - "requirements": ["density", "new_req"], + return { + "your_metric": { + "display_name": "My custom metric", + "requirements": ["density", "new_req"], + }, } - ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) - def my_custom_metric(self, some, arguments): ''' @@ -231,9 +228,9 @@ def eval(self, metrics=None, **plot_kwargs): period: 1000 -In the previous example, the ``update_all_metrics_and_requirements`` method is used to -update the global ``METRICS`` and ``ALL_REQS`` variables. It will be called when the -``MyEvaluator`` class is instantiated, in the init of ``BaseEvaluator``. +In the previous example, the ``define_new_metrics`` method is used to define new +metrics and associated requirements. It will be called when the +``MyEvaluator`` class is instantiated, in the init of ``GFlowNetAbstractEvaluator``. By defining a new requirement, you ensure that the new metrics and plots will only be computed if user asks for a metric that requires such computations. diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 219a761d8..a4ec67146 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -120,7 +120,10 @@ def __init__(self, gfn_agent=None, **config): This ``init`` function will call, in order: - 1. :py:method:`update_all_metrics_and_requirements` + 1. :py:method:`.update_all_metrics_and_requirements` which uses new metrics + defined in the :py:method:`define_new_metrics` method to update the global + :py:const:`METRICS` and :py:const:`ALL_REQS` variables in classes + inheriting from :py:class:`GFlowNetAbstractEvaluator`. 2. ``self.metrics = self.make_metrics(self.config.metrics)`` using :py:method:`make_metrics` @@ -197,12 +200,40 @@ def gfn(self, _): + " GFlowNetAgent." ) + def define_new_metrics(self): + """ + Method to be implemented by subclasses to define new metrics. + + Example + ------- + .. code-block:: python + + def define_new_metrics(self): + return { + "my_custom_metric": { + "display_name": "My custom metric", + "requirements": ["density", "new_req"], + } + } + + Returns + ------- + dict + Dictionary of new metrics to add to the global `METRICS` dict. + """ + pass + def update_all_metrics_and_requirements(self): """ Method to be implemented by subclasses to update the global dict of metrics and requirements. """ - pass + new_metrics = self.define_new_metrics() + if new_metrics: + global METRICS + global ALL_REQS + METRICS.update(new_metrics) + ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) @classmethod def from_dir( From 82f07819743f3e897c0d5f5fb115d742fd1dbcc1 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 14:27:40 -0500 Subject: [PATCH 072/106] test `.` --- gflownet/evaluator/abstract.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index a4ec67146..8fae45966 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -120,7 +120,8 @@ def __init__(self, gfn_agent=None, **config): This ``init`` function will call, in order: - 1. :py:method:`.update_all_metrics_and_requirements` which uses new metrics + 1. :py:method:`.GFlowNetAbstractEvaluator.update_all_metrics_and_requirements` + which uses new metrics defined in the :py:method:`define_new_metrics` method to update the global :py:const:`METRICS` and :py:const:`ALL_REQS` variables in classes inheriting from :py:class:`GFlowNetAbstractEvaluator`. From 34082c1a9ff37fba8d96c31eb791fb779bf5ca2f Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 14:29:53 -0500 Subject: [PATCH 073/106] no `.` ? --- gflownet/evaluator/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 8fae45966..ba65b552c 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -120,7 +120,7 @@ def __init__(self, gfn_agent=None, **config): This ``init`` function will call, in order: - 1. :py:method:`.GFlowNetAbstractEvaluator.update_all_metrics_and_requirements` + 1. :py:method:`GFlowNetAbstractEvaluator.update_all_metrics_and_requirements` which uses new metrics defined in the :py:method:`define_new_metrics` method to update the global :py:const:`METRICS` and :py:const:`ALL_REQS` variables in classes From 43015d18090c46d4d864804f90d9a50edc132365 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 14:44:37 -0500 Subject: [PATCH 074/106] update links --- gflownet/evaluator/abstract.py | 89 +++++++++++++++++++--------------- 1 file changed, 49 insertions(+), 40 deletions(-) diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index ba65b552c..6f82a7b84 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -115,21 +115,25 @@ def __init__(self, gfn_agent=None, **config): to easily instantiate this class from a run directory or an existing in-memory ``GFlowNetAgent``. - Use :py:method:`set_agent` to set the evaluator's ``GFlowNetAgent`` after - initialization if it was not provided at instantiation as ``gfn_agent=``. + Use + :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.set_agent` + to set the evaluator's ``GFlowNetAgent`` after initialization if it was not + provided at instantiation as ``gfn_agent=``. - This ``init`` function will call, in order: + This ``__init__`` function will call, in order: - 1. :py:method:`GFlowNetAbstractEvaluator.update_all_metrics_and_requirements` + 1. :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.update_all_metrics_and_requirements` which uses new metrics - defined in the :py:method:`define_new_metrics` method to update the global - :py:const:`METRICS` and :py:const:`ALL_REQS` variables in classes - inheriting from :py:class:`GFlowNetAbstractEvaluator`. + defined in the + :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.define_new_metrics` + method to update the global :py:const:`METRICS` and :py:const:`ALL_REQS` + variables in classes inheriting from :py:class:`GFlowNetAbstractEvaluator`. 2. ``self.metrics = self.make_metrics(self.config.metrics)`` using :py:method:`make_metrics` - 3. ``self.reqs = self.make_requirements()`` using :py:method:`make_requirements` + 3. ``self.reqs = self.make_requirements()`` using + :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.make_requirements` Arguments --------- @@ -220,7 +224,7 @@ def define_new_metrics(self): Returns ------- dict - Dictionary of new metrics to add to the global `METRICS` dict. + Dictionary of new metrics to add to the global :py:const:`METRICS` dict. """ pass @@ -305,22 +309,22 @@ def from_agent(cls, gfn_agent): def make_metrics(self, metrics=None): """ - Parse metrics from a dict, list, a string or None. + Parse metrics from a dict, list, a string or ``None``. - - If `None`, all metrics are selected. + - If ``None``, all metrics are selected. - If a string, it can be a comma-separated list of metric names, with or without spaces. - - If a list, it should be a list of metric names (keys of `METRICS`). + - If a list, it should be a list of metric names (keys of :py:const:`METRICS`). - If a dict, its keys should be metric names and its values will be ignored: - they will be assigned from `METRICS`. + they will be assigned from :py:const:`METRICS`. - All metrics must be in `METRICS`. + All metrics must be in :py:const:`METRICS`. Parameters ---------- metrics : Union[str, List[str]], optional Metrics to compute when running the `evaluator.eval()` function. Defaults to - None, i.e. all metrics in `METRICS` are computed. + None, i.e. all metrics in :py:const:`METRICS` are computed. Returns ------- @@ -331,7 +335,7 @@ def make_metrics(self, metrics=None): Raises ------ ValueError - If a metric name is not in `METRICS`. + If a metric name is not in :py:const:`METRICS`. """ if metrics is None: assert self.metrics is not _sentinel, ( @@ -376,22 +380,25 @@ def make_requirements(self, reqs=None, metrics=None): """ Make requirements for the metrics to compute. - 1. If `metrics` is provided, they must be as a dict of metrics. The requirements - are computed from the `requirements` attribute of the metrics. + 1. If ``metrics`` is provided, they must be as a dict of metrics. + The requirements are computed from the ``requirements`` attribute of + the metrics. - 2. Otherwise, the requirements are computed from the `reqs` argument: - - If `reqs` is `"all"`, all requirements of all metrics are computed. - - If `reqs` is `None`, the evaluator's `self.reqs` attribute is used. - - If `reqs` is a list, it is used as the requirements. + 2. Otherwise, the requirements are computed from the ``reqs`` argument: + - If ``reqs`` is ``"all"``, all requirements of all metrics are computed. + - If ``reqs`` is ``None``, the evaluator's ``self.reqs`` attribute is used. + - If ``reqs`` is a list, it is used as the requirements. Parameters ---------- reqs : Union[str, List[str]], optional - The metrics requirements. Either `"all"`, a list of requirements or `None` - to use the evaluator's `self.reqs` attribute. By default None + The metrics requirements. Either ``"all"``, a list of requirements or + ``None`` to use the evaluator's ``self.reqs`` attribute. + By default ``None``. metrics : Union[str, List[str], dict], optional The metrics to compute requirements for. If not a dict, will be passed to - `make_metrics`. By default None. + :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.make_metrics``. + By default None. Returns ------- @@ -448,9 +455,9 @@ def make_requirements(self, reqs=None, metrics=None): def should_log_train(self, step): """ Check if training logs should be done at the current step. The decision is based - on the `self.config.train.period` attribute. + on the ``self.config.train.period`` attribute. - Set `self.config.train.period` to `None` or a negative value to disable + Set ``self.config.train.period`` to ``None`` or a negative value to disable training. Parameters @@ -471,13 +478,14 @@ def should_log_train(self, step): def should_eval(self, step): """ Check if testing should be done at the current step. The decision is based on - the `self.config.test.period` attribute. + the ``self.config.test.period`` attribute. - Set `self.config.test.first_it` to `True` if testing should be done at the first - iteration step. Otherwise, testing will be done aftter `self.config.test.period` - steps. + Set ``self.config.test.first_it`` to ``True`` if testing should be done at the + first iteration step. Otherwise, testing will be done aftter + ``self.config.test.period`` steps. - Set `self.config.test.period` to `None` or a negative value to disable testing. + Set ``self.config.test.period`` to ``None`` or a negative value to disable + testing. Parameters ---------- @@ -499,10 +507,10 @@ def should_eval(self, step): def should_eval_top_k(self, step): """ Check if top k plots and metrics should be done at the current step. The - decision is based on the `self.config.test.top_k` and - `self.config.test.top_k_period` attributes. + decision is based on the ``self.config.test.top_k`` and + ``self.config.test.top_k_period`` attributes. - Set `self.config.test.top_k` to `None` or a negative value to disable top k + Set ``self.config.test.top_k`` to ``None`` or a negative value to disable top k plots and metrics. Parameters @@ -529,9 +537,9 @@ def should_eval_top_k(self, step): def should_checkpoint(self, step): """ Check if checkpoints should be done at the current step. The decision is based - on the `self.checkpoints.period` attribute. + on the ``self.checkpoints.period`` attribute. - Set `self.checkpoints.period` to `None` or a negative value to disable + Set ``self.checkpoints.period`` to ``None`` or a negative value to disable checkpoints. Parameters @@ -568,15 +576,16 @@ def eval_and_log(self, it, metrics=None): """ Evaluate the GFlowNetAgent and log the results with its logger. - Will call `self.eval()` and log the results using the GFlowNetAgent's logger - `log_metrics()` and `log_plots()` methods. + Will call ``self.eval()`` and log the results using the GFlowNetAgent's logger + ``log_metrics()`` and ``log_plots()`` methods. Parameters ---------- it : int Current iteration step. metrics : Union[str, List[str]], optional - List of metrics to compute, by default the evaluator's `metrics` attribute. + List of metrics to compute, by default the evaluator's ``metrics`` + attribute. """ results = self.eval(metrics=metrics) for m, v in results["metrics"].items(): From d0cdb27825de7668da321e4f5aeb983417b8bc1d Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 17:15:57 -0500 Subject: [PATCH 075/106] always use `evaluator` --- config/{eval => evaluator}/base.yaml | 0 config/main.yaml | 2 +- gflownet/evaluator/__init__.py | 32 +++++++++++++++++++++++++- gflownet/evaluator/abstract.py | 34 ++++++++++++++++++++++++++-- 4 files changed, 64 insertions(+), 4 deletions(-) rename config/{eval => evaluator}/base.yaml (100%) diff --git a/config/eval/base.yaml b/config/evaluator/base.yaml similarity index 100% rename from config/eval/base.yaml rename to config/evaluator/base.yaml diff --git a/config/main.yaml b/config/main.yaml index 65aa6cfd1..d07fac96a 100644 --- a/config/main.yaml +++ b/config/main.yaml @@ -6,7 +6,7 @@ defaults: - proxy: corners - logger: wandb - user: default - - eval: base + - evaluator: base # Device device: cuda diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index e316c968f..da7c9e4e2 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -26,6 +26,36 @@ class methods should be used to instantiate this class. results using the GFlowNetAgent's logger as ``self.logger.log_metrics(results["metrics"])`` and ``self.logger.log_plots(figs)``. +Basic concepts +-------------- + +The evaluator is used to compute metrics and plots. It is used to evaluate the +performance of the agent during training and to log the results. It is also +intended to be used to evaluate the performance of a trained agent. + +The ``metrics`` keyword argument usually reflect to a description of which quantities +are to be computed. They can take the following forms: + +- ``None``: all metrics defined in the config file / in the evaluator's + ``.config.metrics`` attribute will be computed. + +- ``"all"``: all known metrics as defined in :py:const:`METRICS` will be computed. + + - Note that classes that inherit from :py:class:`GFlowNetEvaluator` can define new + metrics with the :py:meth:`define_new_metrics` method. + +- ``list``: a list of metric names to be computed. The names must be keys of + :py:const:`METRICS`. + +- ``dict``: a dictionary that is a subset of :py:const:`METRICS`. + +The concept of ``requirements`` is used to avoid unnecessary computations. If a metric +requires a certain quantity to be computed, then the evaluator will only compute that +quantity if the metric is requested. This is done by the :py:meth:`make_requirements` +method and can be used in methods that compute metrics and plots like +``if "some_req" in reqs`` (see below for an example). + + Using an Evaluator ------------------ @@ -217,7 +247,7 @@ def eval(self, metrics=None, **plot_kwargs): .. code-block:: yaml - # gflownet/config/evaluator/my_evaluator.yaml + # config/evaluator/my_evaluator.yaml defaults: - base diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 6f82a7b84..0417776d6 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -323,8 +323,10 @@ def make_metrics(self, metrics=None): Parameters ---------- metrics : Union[str, List[str]], optional - Metrics to compute when running the `evaluator.eval()` function. Defaults to - None, i.e. all metrics in :py:const:`METRICS` are computed. + Metrics to compute when running the + :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.eval` + method. Defaults to ``None``, i.e. all metrics in :py:const:`METRICS` + are computed. Returns ------- @@ -566,6 +568,34 @@ def plot(self, **kwargs): @abstractmethod def eval(self, metrics=None, **plot_kwargs): + """ + The main method to compute metrics and intermediate results. + + This method should return a dict with two keys: "metrics" and "data". + + The "metrics" key should contain the new metric(s) and the "data" key should + contain the intermediate results that can be used to plot the new metric(s). + + Example + ------- + >>> metrics = None # use the default metrics from the config file + >>> results = gfne.eval(metrics=metrics) + >>> plots = gfne.plot(**results["data"]) + + >>> metrics = "all" # compute all metrics, regardless of the config + >>> results = gfne.eval(metrics=metrics) + + >>> metrics = ["l1", "kl"] # compute only the L1 and KL metrics + >>> results = gfne.eval(metrics=metrics) + + >>> metrics = "l1,kl" # alternative syntax + >>> results = gfne.eval(metrics=metrics) + + Parameters + ---------- + metrics : Union[str, dict, list], optional + Which metrics to compute, by default ``None``. + """ pass @abstractmethod From 21663da20c63435a3413eedb338e08d72eabae67 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 17:27:04 -0500 Subject: [PATCH 076/106] reference logger --- gflownet/evaluator/__init__.py | 11 ++++++----- gflownet/evaluator/abstract.py | 22 ++++++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index da7c9e4e2..9d1422b03 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -266,10 +266,11 @@ def eval(self, metrics=None, **plot_kwargs): computed if user asks for a metric that requires such computations. By default, the train loop will call -:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` which itself calls -:py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` so if you override ``eval()`` -as above, the new metrics and plots will be computed and logged. +:py:meth:`~gflownet.evaluator.base.GFlowNetAbstractEvaluator.eval_and_log` which itself +calls :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` so if you override +``eval()`` as above, the new metrics and plots will be computed and logged. -Similarly, `eval_and_log` will compute the ``dict`` of figures as -``fig_dict = self.plot(**results["data"])`` where ``results`` is the output of ``eval``. +Similarly, :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` +will compute the ``dict`` of figures as ``fig_dict = self.plot(**results["data"])`` +where ``results`` is the output of ``eval``. """ diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 0417776d6..4ea32b00c 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -83,33 +83,35 @@ def eval_and_log(self, it, metrics=None): }, } """ -All metrics that can be computed by a GFlowNetEvaluator. Structured as a dict with the -metric names as keys and the metric display names and requirements as values. +All metrics that can be computed by a ``GFlowNetEvaluator``. + +Structured as a dict with the metric names as keys and the metric display +names and requirements as values. Requirements are used to decide which kind of data / samples is required to compute the metric. Display names are used to log the metrics and to display them in the console. -Implementations of :py:class:`GFlowNetAbstractEvaluator` should update -this dict and the :py:const:`ALL_REQS` set to include new metrics by implementing the -:py:method:`update_all_metrics_and_requirements` method. +Implementations of :py:class:`GFlowNetAbstractEvaluator` can add new metrics to +this dict by implementing the method +:py:method:`~gflownet.evaluator.abstract.define_new_metrics`. """ ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) """ -Union of all requirements of all metrics in `METRICS`. Computed from -:py:const:`METRICS`. +Union of all requirements of all metrics in :py:const:`METRICS`. """ class GFlowNetAbstractEvaluator(metaclass=ABCMeta): def __init__(self, gfn_agent=None, **config): """ - Base evaluator class for GFlowNetAgent. + Abstract evaluator class for :py:class:`GFlowNetAgent`. - In charge of evaluating the GFlowNetAgent, computing metrics plotting figures - and optionally logging results using the GFlowNetAgent's logger. + In charge of evaluating the :py:class:`GFlowNetAgent`, computing metrics + plotting figures and optionally logging results using the + :py:class:`GFlowNetAgent`'s :py:class:`Logger`. You can use the :py:method:`from_dir` or :py:method:`from_agent` class methods to easily instantiate this class from a run directory or an existing From 33fa8acea5ce8356a1f0161c41a9f8aafc4054b0 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Mon, 4 Mar 2024 17:36:25 -0500 Subject: [PATCH 077/106] more doc polih --- gflownet/evaluator/abstract.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 4ea32b00c..54fdb09e5 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -115,12 +115,12 @@ def __init__(self, gfn_agent=None, **config): You can use the :py:method:`from_dir` or :py:method:`from_agent` class methods to easily instantiate this class from a run directory or an existing - in-memory ``GFlowNetAgent``. + in-memory :py:class:`GFlowNetAgent`. Use :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.set_agent` - to set the evaluator's ``GFlowNetAgent`` after initialization if it was not - provided at instantiation as ``gfn_agent=``. + to set the evaluator's :py:class:`GFlowNetAgent` after initialization if it was + not provided at instantiation as ``GflowNetEvaluator(gfn_agent=...)``. This ``__init__`` function will call, in order: From 1939e2a5414268a3a2fafa750043b815a7413fdf Mon Sep 17 00:00:00 2001 From: vict0rsch Date: Mon, 4 Mar 2024 19:45:17 -0500 Subject: [PATCH 078/106] Improve docs and refactor to `AbstractEvaluator` and `BaseEvaluator` --- README.md | 2 +- config/evaluator/base.yaml | 2 +- docs/conf.py | 16 ++- docs/contributors/example.rst | 16 ++- docs/contributors/write-documentation.rst | 3 +- gflownet/evaluator/__init__.py | 77 +++++++------ gflownet/evaluator/abstract.py | 110 +++++++++--------- gflownet/evaluator/base.py | 130 ++++++++++++++++------ gflownet/gflownet.py | 8 +- tests/gflownet/eval/test_base.py | 6 +- 10 files changed, 228 insertions(+), 142 deletions(-) diff --git a/README.md b/README.md index d58256682..a09e60923 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ The repository supports logging of train and evaluation metrics to [wandb.ai](ht Bibtex Format -```txt +```text @misc{hernandez-garcia2024, author = {Hernandez-Garcia, Alex and Saxena, Nikita and Volokhova, Alexandra and Koziarski, Michał and Sharma, Divya and Viviano, Joseph D and Carrier, Pierre Luc and Schmidt, Victor}, title = {gflownet}, diff --git a/config/evaluator/base.yaml b/config/evaluator/base.yaml index a679fdda7..03af9450f 100644 --- a/config/evaluator/base.yaml +++ b/config/evaluator/base.yaml @@ -1,4 +1,4 @@ -_target_: gflownet.evaluator.base.GFlowNetEvaluator +_target_: gflownet.evaluator.base.BaseEvaluator # config formerly from logger.test first_it: True diff --git a/docs/conf.py b/docs/conf.py index e4fdbcbf3..1210529c0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,11 +53,6 @@ templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - # -- Options for HTML output ------------------------------------------------- @@ -122,6 +117,7 @@ # sphinx.ext.intersphinx intersphinx_mapping = { "torch": ("https://pytorch.org/docs/stable", None), + "omegaconf": ("https://omegaconf.readthedocs.io/en/latest", None), } # sphinx.ext.autodoc & autoapi.extension @@ -179,3 +175,13 @@ "enable": True, "image": "./_static/images/gflownet-logo.png", } + + +def skip_util_classes(app, what, name, obj, skip, options): + return any( + name.startswith(f"gflownet.{p}") for p in ["envs", "proxy", "policy", "utils"] + ) + + +def setup(sphinx): + sphinx.connect("autoapi-skip-member", skip_util_classes) diff --git a/docs/contributors/example.rst b/docs/contributors/example.rst index 0b5a652b6..8d008c7c2 100644 --- a/docs/contributors/example.rst +++ b/docs/contributors/example.rst @@ -26,11 +26,17 @@ Remember, this works in docstrings *and* in stand-alone ``.rst`` files. Cool features: -Reference to a class: :class:`gflownet.proxy.crystals.dave.DAVE` (long), or another -:class:`~gflownet.gflownet.GFlowNetAgent` or to a method: -:meth:`~gflownet.gflownet.GFlowNetAgent.trajectorybalance_loss` -or to an external function :func:`torch.cuda.synchronize()` -(this <- needs to be listed in ``docs/conf.py:intersphinx_mapping``). +Reference code docs of: + +- A class: :class:`gflownet.envs.grid.Grid` (long format) +- Another class :class:`~gflownet.gflownet.GFlowNetAgent` (short format, by prepending ``~``) +- A method :meth:`~gflownet.gflownet.GFlowNetAgent.trajectorybalance_loss` +- Or even an external function :func:`torch.cuda.synchronize()` + +.. note + + External content should be listed in ``docs/conf.py:intersphinx_mapping``. + More info in the `Read The Docs documentation `_. An actual tutorial on ``.rst``: `ReStructured Text for those who know Markdown `_ diff --git a/docs/contributors/write-documentation.rst b/docs/contributors/write-documentation.rst index 8fdff1c4d..53daba247 100644 --- a/docs/contributors/write-documentation.rst +++ b/docs/contributors/write-documentation.rst @@ -13,7 +13,7 @@ Overview There are two major types of documentation: -1. **docstrings**: your code's docstrings will be automatically parsed by the documentation sofware (`Sphinx `_, more in `about shpinx`_). +1. **docstrings**: your code's docstrings will be automatically parsed by the documentation sofware (`Sphinx `_, more in :ref:`about shpinx`). 2. **Manual** documentation such as this document. This can be for instance a detailed installation procedure, a tutorial, a FAQ, a contributor's guide etc. you name it! **Both** are written in `ReStructured Text `_ (``.rst``) format. @@ -152,6 +152,7 @@ FAQ - `Hover X Ref `_ Enables tooltips to display contents on the hover of links - `Napoleon `_ enables the parsing of Google-style docstrings +.. _about shpinx: About Sphinx ------------ diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index 9d1422b03..45ee4d00f 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -1,31 +1,34 @@ """ -Create a new evaluator by subclassing this class and extending the :py:meth:`eval` +Create a new evaluator by subclassing this class and extending the :meth:`eval` method to add more metrics and plots. .. important:: - Only the :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_dir` - and :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_agent` - class methods should be used to instantiate this class. + Prefer the :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_dir` + and :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_agent` + class methods to instantiate an evaluator. Typical call stack: -1. :py:meth:`gflownet.gflownet.GFlowNetAgent.train` calls the evaluator's +1. :meth:`gflownet.gflownet.GFlowNetAgent.train` calls the evaluator's -2. :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.should_eval`. - If it returns ``True`` then :py:meth:`~gflownet.gflownet.GFlowNetAgent.train` calls +2. :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.should_eval`. + If it returns ``True`` then :meth:`~gflownet.gflownet.GFlowNetAgent.train` calls -3. :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.eval_and_log` +3. :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` which itself calls -4. :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` as +4. :meth:`~gflownet.evaluator.base.BaseEvaluator.eval` as ``results = self.eval(metrics=None)`` and then ``figs = self.plot(**results["data"])`` -5. finally, :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` logs the - results using the GFlowNetAgent's logger as +5. finally, :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` + logs the results using the GFlowNetAgent's logger as ``self.logger.log_metrics(results["metrics"])`` and ``self.logger.log_plots(figs)``. + +.. _evaluator basic concepts: + Basic concepts -------------- @@ -39,22 +42,30 @@ class methods should be used to instantiate this class. - ``None``: all metrics defined in the config file / in the evaluator's ``.config.metrics`` attribute will be computed. -- ``"all"``: all known metrics as defined in :py:const:`METRICS` will be computed. +- ``"all"``: all known metrics as defined in + :const:`~gflownet.evaluator.abstract.AbstractEvaluator.METRICS` + will be computed. - - Note that classes that inherit from :py:class:`GFlowNetEvaluator` can define new - metrics with the :py:meth:`define_new_metrics` method. + - Note that classes that inherit from + :class:`~gflownet.evaluator.abstract.AbstractEvaluator` can define new + metrics with the + :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.define_new_metrics` + method. - ``list``: a list of metric names to be computed. The names must be keys of - :py:const:`METRICS`. + :const:`~gflownet.evaluator.abstract.AbstractEvaluator.METRICS`. -- ``dict``: a dictionary that is a subset of :py:const:`METRICS`. +- ``dict``: a dictionary that is a subset of + :const:`~gflownet.evaluator.abstract.AbstractEvaluator.METRICS`. The concept of ``requirements`` is used to avoid unnecessary computations. If a metric requires a certain quantity to be computed, then the evaluator will only compute that -quantity if the metric is requested. This is done by the :py:meth:`make_requirements` +quantity if the metric is requested. This is done by the +:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.make_requirements` method and can be used in methods that compute metrics and plots like ``if "some_req" in reqs`` (see below for an example). +.. _using an evaluator: Using an Evaluator ------------------ @@ -62,10 +73,10 @@ class methods should be used to instantiate this class. .. code-block:: python # How to create a new evaluator: - from gflownet.evaluator.base import GFlowNetEvaluator + from gflownet.evaluator.base import BaseEvaluator gfn_run_dir = "PUT_YOUR_RUN_DIR_HERE" # a run dir contains a .hydra folder - gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) + gfne = BaseEvaluator.from_dir(gfn_run_dir) results = gfne.eval() for name, metric in results["metrics"].items(): @@ -87,9 +98,9 @@ class methods should be used to instantiate this class. .. code-block:: python # gflownet/evaluator/my_evaluator.py - from gflownet.evaluator.base import GFlowNetEvaluator, METRICS, ALL_REQS + from gflownet.evaluator.base import BaseEvaluator, METRICS, ALL_REQS - class MyEvaluator(GFlowNetEvaluator): + class MyEvaluator(BaseEvaluator): def define_new_metrics(self): ''' This method is called when the class is instantiated and is used to update @@ -106,7 +117,7 @@ def define_new_metrics(self): def my_custom_metric(self, some, arguments): ''' Your metric-computing method. It should return a dict with two keys: - "metrics" and "data". + ``"metrics"`` and ``"data"``. The "metrics" key should contain the new metric(s) and the "data" key should contain the intermediate results that can be used to plot the @@ -124,7 +135,7 @@ def my_custom_metric(self, some, arguments): Returns ------- dict - A dict with two keys: "metrics" and "data". + A dict with two keys: ``"metrics"`` and ``"data"``. ''' intermediate = some + arguments @@ -214,7 +225,7 @@ def eval(self, metrics=None, **plot_kwargs): ''' Your custom eval method. - It should return a dict with two keys: "metrics" and "data". + It should return a dict with two keys: ``"metrics"`` and ``"data"``. It will be called by the `eval_and_log` method to log the metrics, @@ -227,7 +238,7 @@ def eval(self, metrics=None, **plot_kwargs): Returns ------- dict - A dict with two keys: "metrics" and "data". + A dict with two keys: ``"metrics"`` and ``"data"``. ''' metrics = self.make_metrics(metrics) reqs = self.make_requirements(metrics=metrics) @@ -260,17 +271,19 @@ def eval(self, metrics=None, **plot_kwargs): In the previous example, the ``define_new_metrics`` method is used to define new metrics and associated requirements. It will be called when the -``MyEvaluator`` class is instantiated, in the init of ``GFlowNetAbstractEvaluator``. +``MyEvaluator`` class is instantiated, in the init of +:class:`~gflownet.evaluator.abstract.AbstractEvaluator`. By defining a new requirement, you ensure that the new metrics and plots will only be computed if user asks for a metric that requires such computations. -By default, the train loop will call -:py:meth:`~gflownet.evaluator.base.GFlowNetAbstractEvaluator.eval_and_log` which itself -calls :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval` so if you override -``eval()`` as above, the new metrics and plots will be computed and logged. +By default, the training loop will call +:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` which itself +calls :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval` so if you +override ``eval()`` as above, the new metrics and plots will be computed and logged. -Similarly, :py:meth:`~gflownet.evaluator.base.GFlowNetEvaluator.eval_and_log` +Similarly, :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` will compute the ``dict`` of figures as ``fig_dict = self.plot(**results["data"])`` -where ``results`` is the output of ``eval``. +where ``results`` is the output of +:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval`. """ diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 54fdb09e5..5a3c89536 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -6,13 +6,15 @@ Should not be used directly, but subclassed to implement specific evaluators for different tasks and environments. -See :py:class:`~gflownet.evaluator.base.GFlowNetEvaluator` for a default, +See :class:`~gflownet.evaluator.base.BaseEvaluator` for a default, concrete implementation of this abstract class. This class handles some logic that will be the same for all evaluators. -The only requirements for a subclass are to implement the `plot` and `eval` methods +The only requirements for a subclass are to implement the +:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval` and +:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.plot` methods which will be called by the -:py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.eval_and_log` method. +:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` method: .. code-block:: python @@ -30,8 +32,8 @@ def eval_and_log(self, it, metrics=None): self.logger.log_metrics(mertics_to_log, it, self.gfn.use_context) self.logger.log_plots(figs, it, use_context=self.gfn.use_context) -See :py:mod:`gflownet.evaluator` for a full-fledged example and -:py:mod:`gflownet.evaluator.base` for a concrete implementation of this abstract class. +See :mod:`gflownet.evaluator` for a full-fledged example and +:mod:`gflownet.evaluator.base` for a concrete implementation of this abstract class. """ import os @@ -42,7 +44,11 @@ def eval_and_log(self, it, metrics=None): from gflownet.utils.common import load_gflow_net_from_run_path +# purposefully non-documented object, hidden from Sphinx docs _sentinel = object() +""" +A sentinel object to be used as a default value for arguments that could be None. +""" METRICS = { "l1": { @@ -83,7 +89,7 @@ def eval_and_log(self, it, metrics=None): }, } """ -All metrics that can be computed by a ``GFlowNetEvaluator``. +All metrics that can be computed by a ``BaseEvaluator``. Structured as a dict with the metric names as keys and the metric display names and requirements as values. @@ -93,70 +99,60 @@ def eval_and_log(self, it, metrics=None): Display names are used to log the metrics and to display them in the console. -Implementations of :py:class:`GFlowNetAbstractEvaluator` can add new metrics to +Implementations of :class:`AbstractEvaluator` can add new metrics to this dict by implementing the method -:py:method:`~gflownet.evaluator.abstract.define_new_metrics`. +:meth:`AbstractEvaluator.define_new_metrics`. """ ALL_REQS = set([r for m in METRICS.values() for r in m["requirements"]]) """ -Union of all requirements of all metrics in :py:const:`METRICS`. +Union of all requirements of all metrics in :const:`METRICS`. """ -class GFlowNetAbstractEvaluator(metaclass=ABCMeta): +class AbstractEvaluator(metaclass=ABCMeta): def __init__(self, gfn_agent=None, **config): """ - Abstract evaluator class for :py:class:`GFlowNetAgent`. + Abstract evaluator class for :class:`GFlowNetAgent`. - In charge of evaluating the :py:class:`GFlowNetAgent`, computing metrics + In charge of evaluating the :class:`GFlowNetAgent`, computing metrics plotting figures and optionally logging results using the - :py:class:`GFlowNetAgent`'s :py:class:`Logger`. + :class:`GFlowNetAgent`'s :class:`Logger`. - You can use the :py:method:`from_dir` or :py:method:`from_agent` class methods + You can use the :meth:`from_dir` or :meth:`from_agent` class methods to easily instantiate this class from a run directory or an existing - in-memory :py:class:`GFlowNetAgent`. + in-memory :class:`GFlowNetAgent`. Use - :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.set_agent` - to set the evaluator's :py:class:`GFlowNetAgent` after initialization if it was + :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.set_agent` + to set the evaluator's :class:`GFlowNetAgent` after initialization if it was not provided at instantiation as ``GflowNetEvaluator(gfn_agent=...)``. This ``__init__`` function will call, in order: - 1. :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.update_all_metrics_and_requirements` - which uses new metrics - defined in the - :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.define_new_metrics` - method to update the global :py:const:`METRICS` and :py:const:`ALL_REQS` - variables in classes inheriting from :py:class:`GFlowNetAbstractEvaluator`. + 1. :meth:`update_all_metrics_and_requirements` which uses new metrics defined in + the :meth:`define_new_metrics` method to update the global :const:`METRICS` + and :const:`ALL_REQS` variables in classes inheriting from + :class:`AbstractEvaluator`. 2. ``self.metrics = self.make_metrics(self.config.metrics)`` using - :py:method:`make_metrics` + :meth:`make_metrics` - 3. ``self.reqs = self.make_requirements()`` using - :py:method:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.make_requirements` + 3. ``self.reqs = self.make_requirements()`` using :meth:`make_requirements` Arguments --------- gfn_agent : GFlowNetAgent, optional The GFlowNetAgent to evaluate. By default None. Should be set using the - :py:method:`from_dir` or :py:method:`from_agent` class methods. + :meth:`from_dir` or :meth:`from_agent` class methods. config : dict The configuration of the evaluator. Will be converted to an OmegaConf instance and stored in the ``self.config`` attribute. - Raises - ------ - NotImplementedError - If the `sentinel` keyword argument is not `_sentinel`, which is used to - prevent instantiation of the base class without using the `from_dir` or - `from_agent` class methods. - Attributes ---------- - config : OmegaConf + config : :class:`omegaconf.OmegaConf` The configuration of the evaluator. metrics : dict Dictionary of metrics to compute, with the metric names as keys and the @@ -167,7 +163,7 @@ def __init__(self, gfn_agent=None, **config): logger : Logger The logger to use to log the results of the evaluation. Will be set to the GFlowNetAgent's logger. - gfn: :py:class:`GFlowNetAgent` + gfn: :class:`GFlowNetAgent` The GFlowNetAgent to evaluate. """ @@ -226,7 +222,7 @@ def define_new_metrics(self): Returns ------- dict - Dictionary of new metrics to add to the global :py:const:`METRICS` dict. + Dictionary of new metrics to add to the global :const:`METRICS` dict. """ pass @@ -244,7 +240,7 @@ def update_all_metrics_and_requirements(self): @classmethod def from_dir( - cls: "GFlowNetAbstractEvaluator", + cls: "AbstractEvaluator", path: Union[str, os.PathLike], no_wandb: bool = True, print_config: bool = False, @@ -252,11 +248,11 @@ def from_dir( load_final_ckpt: bool = True, ): """ - Instantiate a GFlowNetEvaluator from a run directory. + Instantiate a BaseEvaluator from a run directory. Parameters ---------- - cls : GFlowNetEvaluator + cls : BaseEvaluator Class to instantiate. path : Union[str, os.PathLike] Path to the run directory from which to load the GFlowNetAgent. @@ -271,8 +267,8 @@ def from_dir( Returns ------- - GFlowNetEvaluator - Instance of GFlowNetEvaluator with the GFlowNetAgent loaded from the run. + BaseEvaluator + Instance of BaseEvaluator with the GFlowNetAgent loaded from the run. """ gfn_agent, _ = load_gflow_net_from_run_path( path, @@ -286,19 +282,19 @@ def from_dir( @classmethod def from_agent(cls, gfn_agent): """ - Instantiate a GFlowNetEvaluator from a GFlowNetAgent. + Instantiate a BaseEvaluator from a GFlowNetAgent. Parameters ---------- - cls : GFlowNetEvaluator + cls : BaseEvaluator Evaluator class to instantiate. gfn_agent : GFlowNetAgent - Instance of GFlowNetAgent to use for the GFlowNetEvaluator. + Instance of GFlowNetAgent to use for the BaseEvaluator. Returns ------- - GFlowNetEvaluator - Instance of GFlowNetEvaluator with the provided GFlowNetAgent. + BaseEvaluator + Instance of BaseEvaluator with the provided GFlowNetAgent. """ from gflownet.gflownet import GFlowNetAgent @@ -316,19 +312,18 @@ def make_metrics(self, metrics=None): - If ``None``, all metrics are selected. - If a string, it can be a comma-separated list of metric names, with or without spaces. - - If a list, it should be a list of metric names (keys of :py:const:`METRICS`). + - If a list, it should be a list of metric names (keys of :const:`METRICS`). - If a dict, its keys should be metric names and its values will be ignored: - they will be assigned from :py:const:`METRICS`. + they will be assigned from :const:`METRICS`. - All metrics must be in :py:const:`METRICS`. + All metrics must be in :const:`METRICS`. Parameters ---------- metrics : Union[str, List[str]], optional Metrics to compute when running the - :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.eval` - method. Defaults to ``None``, i.e. all metrics in :py:const:`METRICS` - are computed. + :meth:`.eval` method. Defaults to ``None``, i.e. all metrics in + :const:`METRICS` are computed. Returns ------- @@ -339,7 +334,7 @@ def make_metrics(self, metrics=None): Raises ------ ValueError - If a metric name is not in :py:const:`METRICS`. + If a metric name is not in :const:`METRICS`. """ if metrics is None: assert self.metrics is not _sentinel, ( @@ -401,8 +396,7 @@ def make_requirements(self, reqs=None, metrics=None): By default ``None``. metrics : Union[str, List[str], dict], optional The metrics to compute requirements for. If not a dict, will be passed to - :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.make_metrics``. - By default None. + :meth:`make_metrics`. By default None. Returns ------- @@ -573,7 +567,7 @@ def eval(self, metrics=None, **plot_kwargs): """ The main method to compute metrics and intermediate results. - This method should return a dict with two keys: "metrics" and "data". + This method should return a dict with two keys: ``"metrics"`` and ``"data"``. The "metrics" key should contain the new metric(s) and the "data" key should contain the intermediate results that can be used to plot the new metric(s). @@ -593,6 +587,8 @@ def eval(self, metrics=None, **plot_kwargs): >>> metrics = "l1,kl" # alternative syntax >>> results = gfne.eval(metrics=metrics) + See :ref:`evaluator basic concepts` for more details about ``metrics``. + Parameters ---------- metrics : Union[str, dict, list], optional diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 4bec305b3..93e69bcff 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -1,17 +1,20 @@ """ -Base evaluator class for GFlowNetAgent. +Base evaluator class for a :class:`~gflownet.gflownet.GFlowNetAgent`. -In charge of evaluating a generic GFlowNetAgent, computing metrics plotting figures and -optionally logging results using the GFlowNetAgent's logger. +In charge of evaluating a generic :class:`~gflownet.gflownet.GFlowNetAgent`, +computing metrics plotting figures and optionally logging results using the +:class:`~gflownet.gflownet.GFlowNetAgent`'s :class:`~gflownet.utils.logger.Logger`. -Take it as example to implement your own evaluator class for your custom use-case. +Take this :class:`BaseEvaluator` as example to implement your own evaluator class +for your custom use-case. .. important:: - Only the :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_dir` - and :py:meth:`~gflownet.evaluator.abstract.GFlowNetAbstractEvaluator.from_agent` - class methods should be used to instantiate this class. + Prefer the :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_dir` + and :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_agent` + class methods to instantiate an evaluator. +See :ref:`using an evaluator` for more details about how to use an Evaluator. """ import copy @@ -25,12 +28,35 @@ class methods should be used to instantiate this class. from gflownet.evaluator.abstract import ALL_REQS # noqa from gflownet.evaluator.abstract import METRICS # noqa -from gflownet.evaluator.abstract import GFlowNetAbstractEvaluator +from gflownet.evaluator.abstract import AbstractEvaluator from gflownet.utils.batch import Batch from gflownet.utils.common import batch_with_rest, tfloat, torch2np -class GFlowNetEvaluator(GFlowNetAbstractEvaluator): +class BaseEvaluator(AbstractEvaluator): + + def __init__(self, gfn_agent=None, **config): + """ + Base evaluator class for GFlowNetAgent. + + In particular, implements the :meth:`eval` with: + + - :meth:`compute_log_prob_metrics` to compute log-probability metrics. + - :meth:`compute_density_metrics` to compute density metrics. + + And the :meth:`plot` method with: + + - The :class:`~gflownet.envs.base.GFlowNetEnv`'s :meth:`plot_reward_samples` + method. + - The :class:`~gflownet.envs.base.GFlowNetEnv`'s :meth:`plot_kde` method if + it exists, for both the ``kde_pred`` and ``kde_true`` arguments if they are + returned in the ``"data"`` dict of the :meth:`eval` method. + + See the :class:`~gflownet.evaluator.abstract.AbstractEvaluator` for more + details about other methods and attributes, including the + :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.__init__`. + """ + super().__init__(gfn_agent, **config) @torch.no_grad() def eval_top_k(self, it, gfn_states=None, random_states=None): @@ -148,6 +174,36 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): } def compute_log_prob_metrics(self, x_tt, metrics=None): + """ + Compute log-probability metrics for the given test data. + + Uses :meth:`~gflownet.gflownet.GFlowNetAgent.estimate_logprobs_data`. + + Known metrics: + + - ``mean_logprobs_std``: Mean of the standard deviation of the log-probabilities. + - ``mean_probs_std``: Mean of the standard deviation of the probabilities. + - ``corr_prob_traj_rewards``: Correlation between the probabilities and the + rewards. + - ``var_logrewards_logp``: Variance of the log-rewards minus the log-probabilities. + - ``nll_tt``: Negative log-likelihood of the test data. + - ``logprobs_std_nll_ratio``: Ratio of the mean of the standard deviation of the + log-probabilities over the negative log-likelihood of the test data. + + + Parameters + ---------- + x_tt : torch.Tensor + Test data. + metrics : List[str], optional + List of metrics to compute, by default ``None`` i.e. the evaluator's + ``self.metrics`` + + Returns + ------- + dict + Computed dict of metrics and data as ``{"metrics": {str: float}}``. + """ metrics = self.make_metrics(metrics) reqs = self.make_requirements(metrics=metrics) @@ -201,6 +257,38 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): } def compute_density_metrics(self, x_tt, dict_tt, metrics=None): + """ + Compute density metrics for the given test data. + + Known metrics: + + - ``l1``: L1 error between the true and predicted densities. + - ``kl``: KL divergence between the true and predicted densities. + - ``jsd``: Jensen-Shannon divergence between the true and predicted densities. + + Returned data in the ``"data"`` sub-dict: + + - ``x_sampled``: Sampled states from the GFN. + - ``kde_pred``: KDE policy as per + :meth:`~gflownet.envs.base.GFlowNetEnv.fit_kde`. + - ``kde_true``: True KDE. + + Parameters + ---------- + x_tt : torch.Tensor + Test data. + dict_tt : dict + Dictionary of test data. + metrics : List[str], optional + List of metrics to compute, by default ``None`` i.e. the evaluator's + ``self.metrics`` + + Returns + ------- + dict + Computed dict of metrics and data as + ``{"metrics": {str: float}, "data": {str: object}}``. + """ metrics = self.make_metrics(metrics) density_metrics = {} @@ -434,27 +522,3 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): "GFlowNet KDE Policy": fig_kde_pred, "Reward KDE": fig_kde_true, } - - -if __name__ == "__main__": - # Try using the GFlowNetEvaluator by running this script from the root: - # $ ipython - # In [1]: run gflownet/evaluator/base.py - # - # Note: this will not work on previous checkpoints whose config does not contain an - # `eval` entry, you have to run one. Add `eval.checkpoint_period=10` to quickly - # have a checkpoint to test. - - gfn_run_dir = "PUT_YOUR_RUN_DIR_HERE" # a run dir contains a .hydra folder - - gfne = GFlowNetEvaluator.from_dir(gfn_run_dir) - results = gfne.eval() - - for name, metric in results["metrics"].items(): - print(f"{name:20}: {metric:.4f}") - - print( - "Available figures in results['figs']:", - ", ".join([fname for fname, fig in results["figs"].items() if fig is not None]) - or "None", - ) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index a5d23fe91..1c4c1b881 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -17,7 +17,7 @@ from tqdm import tqdm from gflownet.envs.base import GFlowNetEnv -from gflownet.evaluator.base import GFlowNetEvaluator +from gflownet.evaluator.base import BaseEvaluator from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer from gflownet.utils.common import ( @@ -94,7 +94,7 @@ def __init__( (`gflownet/utils/logger.py:Logger`). num_empirical_loss : int Number of empirical loss samples to be used for training. - evaluator : gflownet.evaluator.base.GFlowNetEvaluator + evaluator : gflownet.evaluator.base.BaseEvaluator :py:mod:`~gflownet.evaluator` ``Evaluator`` instance. state_flow : dict, optional State flow config dictionary. See `gflownet.yaml:state_flow` for details. By @@ -104,10 +104,10 @@ def __init__( `active_learning: bool` flag. By default False. replay_sampling : str, optional Type of sampling for the replay buffer. See - :method:`~gflownet.utils.buffer.select`. By default "permutation". + :meth:`~gflownet.utils.buffer.select`. By default "permutation". train_sampling : str, optional Type of sampling for the train buffer (offline backward trajectories). See - :method:`~gflownet.utils.buffer.select`. By default "permutation". + :meth:`~gflownet.utils.buffer.select`. By default "permutation". Raises ------ diff --git a/tests/gflownet/eval/test_base.py b/tests/gflownet/eval/test_base.py index c3feed63c..9ba499f2d 100644 --- a/tests/gflownet/eval/test_base.py +++ b/tests/gflownet/eval/test_base.py @@ -4,7 +4,7 @@ import pytest from omegaconf import OmegaConf -from gflownet.evaluator.base import METRICS, GFlowNetEvaluator, _sentinel +from gflownet.evaluator.base import METRICS, BaseEvaluator, _sentinel PERIOD_STEP_TARGET = [ (0, 0, False), @@ -28,7 +28,7 @@ (3, 3, True), ] -CONSTANT_EVALUATOR = GFlowNetEvaluator( +CONSTANT_EVALUATOR = BaseEvaluator( gfn_agent=OmegaConf.create({"eval_config": {"metrics": "all"}, "logger": {}}), sentinel=_sentinel, ) @@ -42,7 +42,7 @@ def dummy_evaluator(config_for_tests): "logger": config_for_tests.logger, } ) - return GFlowNetEvaluator(gfn_agent=gfna_dummy, sentinel=_sentinel) + return BaseEvaluator(gfn_agent=gfna_dummy, sentinel=_sentinel) @pytest.fixture From ef99765edd110387c2c6de6058ec09b25a2dd33c Mon Sep 17 00:00:00 2001 From: vict0rsch Date: Mon, 4 Mar 2024 19:45:56 -0500 Subject: [PATCH 079/106] comment-out trailing dev docs rendering filter Keep for future dev use --- docs/conf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 1210529c0..c3199034b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -177,11 +177,11 @@ } -def skip_util_classes(app, what, name, obj, skip, options): - return any( - name.startswith(f"gflownet.{p}") for p in ["envs", "proxy", "policy", "utils"] - ) +# def skip_util_classes(app, what, name, obj, skip, options): +# return any( +# name.startswith(f"gflownet.{p}") for p in ["envs", "proxy", "policy", "utils"] +# ) -def setup(sphinx): - sphinx.connect("autoapi-skip-member", skip_util_classes) +# def setup(sphinx): +# sphinx.connect("autoapi-skip-member", skip_util_classes) From 02e31eff9d73b4200deaed242f25d70bb71a0dec Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 08:36:50 -0500 Subject: [PATCH 080/106] improve logging --- gflownet/gflownet.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 1c4c1b881..2ac65648a 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn from torch.distributions import Bernoulli -from tqdm import tqdm +from tqdm import tqdm, trange from gflownet.envs.base import GFlowNetEnv from gflownet.evaluator.base import BaseEvaluator @@ -964,19 +964,23 @@ def estimate_logprobs_data( "Sampling backward actions from test data to estimate logprobs...", flush=True, ) - pbar = tqdm(total=n_states) + pbar = tqdm(total=n_states, disable=not self.logger.progress) + pbar2 = trange( + end_batch * n_trajectories, disable=not self.logger.progress, leave=False + ) while init_batch < n_states: batch = Batch(env=self.env, device=self.device, float_type=self.float) # Create an environment for each data point and trajectory and set the state envs = [] + pbar2.reset() for state_idx in range(init_batch, end_batch): for traj_idx in range(n_trajectories): idx = int(mult_indices * state_idx + traj_idx) env = self.env.copy().reset(idx) env.set_state(states_term[state_idx], done=True) envs.append(env) + pbar2.update(1) # Sample trajectories - max_iters = n_trajectories * max_iters_per_traj while envs: # Sample backward actions actions = self.sample_actions( @@ -1023,6 +1027,8 @@ def estimate_logprobs_data( ) logprobs_std = torch.std(logprobs_estimates_bs, dim=-1) probs_std = torch.std(torch.exp(logprobs_estimates_bs), dim=-1) + pbar.close() + pbar2.close() print("Done computing logprobs", flush=True) return logprobs_estimates, logprobs_std, probs_std From 64794ad629ce73343b16bec9eb40688b19eba6b8 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 08:37:06 -0500 Subject: [PATCH 081/106] use `evaluator` namesmace --- gflownet/utils/common.py | 3 +-- main.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index fffe8cc9b..3a8056e96 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -259,7 +259,7 @@ def gflownet_from_config(config): ) # The evaluator is used to compute metrics and plots - evaluator = instantiate(config.eval) + evaluator = instantiate(config.evaluator) # The policy is used to model the probability of a forward/backward action forward_config = parse_policy_config(config, kind="forward") @@ -303,7 +303,6 @@ def gflownet_from_config(config): buffer=config.env.buffer, logger=logger, evaluator=evaluator, - full_config=config, ) return gflownet diff --git a/main.py b/main.py index 036d9fc45..ab812f63d 100644 --- a/main.py +++ b/main.py @@ -46,7 +46,7 @@ def main(config): float_precision=config.float_precision, ) # The evaluator is used to compute metrics and plots - evaluator = hydra.utils.instantiate(config.eval) + evaluator = hydra.utils.instantiate(config.evaluator) # The policy is used to model the probability of a forward/backward action forward_config = parse_policy_config(config, kind="forward") backward_config = parse_policy_config(config, kind="backward") From b3be8607c5e16ab0230628aae9f2802ebeb49b49 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 08:47:10 -0500 Subject: [PATCH 082/106] adapt jay --- config/experiments/scrabble/jay.yaml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/config/experiments/scrabble/jay.yaml b/config/experiments/scrabble/jay.yaml index 116595238..6af4317f2 100644 --- a/config/experiments/scrabble/jay.yaml +++ b/config/experiments/scrabble/jay.yaml @@ -5,6 +5,7 @@ defaults: - override /env: scrabble - override /gflownet: trajectorybalance + - override /evaluator: base - override /proxy: scrabble - override /logger: wandb - override /user: alex @@ -51,21 +52,21 @@ policy: shared_weights: False checkpoint: backward +# Evaluator +period: 500 +n: 1000 +checkpoints_period: 500 + # WandB logger: do: online: true lightweight: True project_name: "scrabble" - tags: + tags: - gflownet - discrete - scrabble - test: - period: 500 - n: 1000 - checkpoints: - period: 500 # Hydra hydra: From 1d883d39201bf3cf134882144dfef78049f8cc01 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 08:47:22 -0500 Subject: [PATCH 083/106] move metrics to base --- gflownet/evaluator/abstract.py | 69 ++++++++++++++-------------------- gflownet/evaluator/base.py | 42 ++++++++++++++++++++- 2 files changed, 69 insertions(+), 42 deletions(-) diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 5a3c89536..71acfcfcd 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -50,44 +50,7 @@ def eval_and_log(self, it, metrics=None): A sentinel object to be used as a default value for arguments that could be None. """ -METRICS = { - "l1": { - "display_name": "L1 error", - "requirements": ["density"], - }, - "kl": { - "display_name": "KL Div.", - "requirements": ["density"], - }, - "jsd": { - "display_name": "Jensen Shannon Div.", - "requirements": ["density"], - }, - "corr_prob_traj_rewards": { - "display_name": "Corr. (test probs., rewards)", - "requirements": ["log_probs", "reward_batch"], - }, - "var_logrewards_logp": { - "display_name": "Var(logR - logp) test", - "requirements": ["log_probs", "reward_batch"], - }, - "nll_tt": { - "display_name": "NLL of test data", - "requirements": ["log_probs"], - }, - "mean_logprobs_std": { - "display_name": "Mean BS Std(logp)", - "requirements": ["log_probs"], - }, - "mean_probs_std": { - "display_name": "Mean BS Std(p)", - "requirements": ["log_probs"], - }, - "logprobs_std_nll_ratio": { - "display_name": "BS Std(logp) / NLL", - "requirements": ["log_probs"], - }, -} +METRICS = {} """ All metrics that can be computed by a ``BaseEvaluator``. @@ -163,8 +126,6 @@ def __init__(self, gfn_agent=None, **config): logger : Logger The logger to use to log the results of the evaluation. Will be set to the GFlowNetAgent's logger. - gfn: :class:`GFlowNetAgent` - The GFlowNetAgent to evaluate. """ self._gfn_agent = gfn_agent @@ -181,14 +142,40 @@ def __init__(self, gfn_agent=None, **config): @property def gfn(self): + """ + Get the ``GFlowNetAgent`` to evaluate. + + This is a read-only property. Use the :meth:`set_agent` method to set + the ``GFlowNetAgent``. + + Returns + ------- + :class:`GFlowNetAgent` + The ``GFlowNetAgent`` to evaluate. + + Raises + ------ + ValueError + If the ``GFlowNetAgent`` has not been set. + """ if type(self._gfn_agent).__name__ != "GFlowNetAgent": raise ValueError( "The GFlowNetAgent has not been set. Use the `from_dir` or `from_agent`" - + " class methods to instantiate this class or set the `gfn` attribute." + + " class methods to instantiate this class or the `set_agent` method" ) return self._gfn_agent def set_agent(self, gfn_agent): + """ + Set the ``GFlowNetAgent`` to evaluate after initialization. + + It is then accessible through the ``self.gfn`` property. + + Parameters + ---------- + gfn_agent : :class:`GFlowNetAgent` + The ``GFlowNetAgent`` to evaluate. + """ assert type(gfn_agent).__name__ == "GFlowNetAgent", ( "gfn_agent should be an instance of GFlowNetAgent, but is an instance of " + f"{type(gfn_agent)}." diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 93e69bcff..7c36786c2 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -58,6 +58,46 @@ def __init__(self, gfn_agent=None, **config): """ super().__init__(gfn_agent, **config) + def define_new_metrics(self): + return { + "l1": { + "display_name": "L1 error", + "requirements": ["density"], + }, + "kl": { + "display_name": "KL Div.", + "requirements": ["density"], + }, + "jsd": { + "display_name": "Jensen Shannon Div.", + "requirements": ["density"], + }, + "corr_prob_traj_rewards": { + "display_name": "Corr. (test probs., rewards)", + "requirements": ["log_probs", "reward_batch"], + }, + "var_logrewards_logp": { + "display_name": "Var(logR - logp) test", + "requirements": ["log_probs", "reward_batch"], + }, + "nll_tt": { + "display_name": "NLL of test data", + "requirements": ["log_probs"], + }, + "mean_logprobs_std": { + "display_name": "Mean BS Std(logp)", + "requirements": ["log_probs"], + }, + "mean_probs_std": { + "display_name": "Mean BS Std(p)", + "requirements": ["log_probs"], + }, + "logprobs_std_nll_ratio": { + "display_name": "BS Std(logp) / NLL", + "requirements": ["log_probs"], + }, + } + @torch.no_grad() def eval_top_k(self, it, gfn_states=None, random_states=None): """ @@ -296,7 +336,7 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): x_sampled = density_true = density_pred = None - if self.gfn.buffer.test_type is not None and self.gfn.buffer.test_type == "all": + if self.gfn.buffer.test_type == "all": batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False) assert batch.is_valid() x_sampled = batch.get_terminating_states() From 92182b11acfed2bd1f519b09c372036d70868181 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 09:03:21 -0500 Subject: [PATCH 084/106] fix tests --- config/tests.yaml | 2 +- .../gflownet/{eval => evaluator}/test_base.py | 23 ++++++------------- 2 files changed, 8 insertions(+), 17 deletions(-) rename tests/gflownet/{eval => evaluator}/test_base.py (92%) diff --git a/config/tests.yaml b/config/tests.yaml index ba3346aea..ddb6503c9 100644 --- a/config/tests.yaml +++ b/config/tests.yaml @@ -5,7 +5,7 @@ defaults: - policy: mlp - logger: base - user: alex - - eval: base + - evaluator: base - _self_ # Device diff --git a/tests/gflownet/eval/test_base.py b/tests/gflownet/evaluator/test_base.py similarity index 92% rename from tests/gflownet/eval/test_base.py rename to tests/gflownet/evaluator/test_base.py index 9ba499f2d..85f5f81b5 100644 --- a/tests/gflownet/eval/test_base.py +++ b/tests/gflownet/evaluator/test_base.py @@ -4,7 +4,8 @@ import pytest from omegaconf import OmegaConf -from gflownet.evaluator.base import METRICS, BaseEvaluator, _sentinel +from gflownet.evaluator.abstract import METRICS, _sentinel +from gflownet.evaluator.base import BaseEvaluator PERIOD_STEP_TARGET = [ (0, 0, False), @@ -28,28 +29,17 @@ (3, 3, True), ] -CONSTANT_EVALUATOR = BaseEvaluator( - gfn_agent=OmegaConf.create({"eval_config": {"metrics": "all"}, "logger": {}}), - sentinel=_sentinel, -) +CONSTANT_EVALUATOR = BaseEvaluator(metrics="all") @pytest.fixture def dummy_evaluator(config_for_tests): - gfna_dummy = OmegaConf.create( - { - "eval_config": config_for_tests.eval, - "logger": config_for_tests.logger, - } - ) - return BaseEvaluator(gfn_agent=gfna_dummy, sentinel=_sentinel) + return BaseEvaluator(**config_for_tests.evaluator) @pytest.fixture def constant_evaluator(): # faster fixture for state-less tests - CONSTANT_EVALUATOR.config = OmegaConf.create( - {"eval_config": {"metrics": "all"}, "logger": {}} - ) + CONSTANT_EVALUATOR.config = OmegaConf.create({"metrics": "all"}) return CONSTANT_EVALUATOR @@ -234,6 +224,7 @@ def test__eval(gflownet_for_tests, parameterization): assert Path("./replay.pkl").exists() # results: {"metrics": dict[str, float], "figs": list[plt.Figure]} results = gflownet_for_tests.evaluator.eval() + figs = gflownet_for_tests.evaluator.plot(**results["data"]) for k, v in results["metrics"].items(): assert isinstance(k, str) @@ -244,7 +235,7 @@ def test__eval(gflownet_for_tests, parameterization): elif parameterization == "grid_length_4": pass elif parameterization == "ctorus": - for figname, fig in results["figs"].items(): + for figname, fig in figs.items(): assert isinstance(figname, str) assert isinstance(fig, plt.Figure) else: From 8139ab0f928ced1bef97754615a7a2dbbb915a40 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 09:40:44 -0500 Subject: [PATCH 085/106] clean up prints --- gflownet/gflownet.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 2ac65648a..172b9ccd0 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -926,7 +926,6 @@ def estimate_logprobs_data( probs_std: torch.tensor Bootstrap std of the torch.exp(logprobs_estimates) """ - print("Compute logprobs...", flush=True) times = {} # Determine terminating states if isinstance(data, list): @@ -960,19 +959,23 @@ def estimate_logprobs_data( mult_indices = max(n_states, n_trajectories) init_batch = 0 end_batch = min(batch_size, n_states) - print( - "Sampling backward actions from test data to estimate logprobs...", - flush=True, + pbar = tqdm( + total=n_states, + disable=not self.logger.progress, + leave=False, + desc="Sampling backward actions from test data to estimate logprobs", ) - pbar = tqdm(total=n_states, disable=not self.logger.progress) pbar2 = trange( - end_batch * n_trajectories, disable=not self.logger.progress, leave=False + end_batch * n_trajectories, + disable=not self.logger.progress, + leave=False, + desc="Setting env terminal states", ) while init_batch < n_states: batch = Batch(env=self.env, device=self.device, float_type=self.float) # Create an environment for each data point and trajectory and set the state envs = [] - pbar2.reset() + pbar2.reset((end_batch - init_batch) * n_trajectories) for state_idx in range(init_batch, end_batch): for traj_idx in range(n_trajectories): idx = int(mult_indices * state_idx + traj_idx) @@ -1029,7 +1032,6 @@ def estimate_logprobs_data( probs_std = torch.std(torch.exp(logprobs_estimates_bs), dim=-1) pbar.close() pbar2.close() - print("Done computing logprobs", flush=True) return logprobs_estimates, logprobs_std, probs_std def train(self): From dce8abf18ba876fc5e07fa007979ac8f5656dfec Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 09:40:49 -0500 Subject: [PATCH 086/106] use evaluator --- config/experiments/icml23/ctorus.yaml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/config/experiments/icml23/ctorus.yaml b/config/experiments/icml23/ctorus.yaml index eb33c5d41..d7477e64e 100644 --- a/config/experiments/icml23/ctorus.yaml +++ b/config/experiments/icml23/ctorus.yaml @@ -2,6 +2,7 @@ defaults: - override /env: ctorus + - override /evaluator: base - override /gflownet: trajectorybalance - override /proxy: torus - override /logger: wandb @@ -33,6 +34,12 @@ gflownet: lr_z_mult: 1000 n_train_steps: 5000 +# Evaluator +evaluator: + period: 25 + n: 1000 + checkpoints_period: 500 + # Policy policy: forward: @@ -51,15 +58,10 @@ policy: logger: lightweight: True project_name: "Continuous GFlowNet" - tags: + tags: - gflownet - continuous - ctorus - test: - period: 25 - n: 1000 - checkpoints: - period: 500 # Hydra hydra: From 0322a41a9ca5efb3975857bb9cbdbeb42d538301 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 09:49:05 -0500 Subject: [PATCH 087/106] evaluator in tests instantiate --- tests/gflownet/envs/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 294743c3c..f2abdc11b 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -450,6 +450,7 @@ def test__gflownet_minimal_runs(self, n_repeat=1): device=config.device, float_precision=config.float_precision, ) + evaluator = hydra.utils.instantiate(config.evaluator) # Policy forward_config = parse_policy_config(config, kind="forward") @@ -482,7 +483,7 @@ def test__gflownet_minimal_runs(self, n_repeat=1): backward_policy=backward_policy, buffer=config.env.buffer, logger=logger, - eval_config=config.eval, + evaluator=evaluator, ) gflownet.train() assert True From 8b27ac3681b50f492eb008b97231708612fef463 Mon Sep 17 00:00:00 2001 From: Victor Schmidt Date: Tue, 5 Mar 2024 10:10:48 -0500 Subject: [PATCH 088/106] improve init docs --- gflownet/evaluator/__init__.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index 45ee4d00f..cefc88f6a 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -1,6 +1,22 @@ """ -Create a new evaluator by subclassing this class and extending the :meth:`eval` -method to add more metrics and plots. +An ``Evaluator`` is a class that is used to compute metrics and plots. +It serves two complementary purposes: + +1. It is used to evaluate the performance of the agent during training and to log the + results. +2. It is intended to be used to evaluate the performance of a trained agent, from a + directory containing the agent's checkpoints for instance. + +.. note:: + + This dual use explains some seaminlgy redundant methods / or arguments to methods. + + For instance in :`gflownet.evaluator.abstract.AbstractEvaluator.eval` the + ``metrics`` argument will never change during the training of a GflowNet (it will + always be ``None``, *i.e.* inherited from the config file) but a user looking to + evaluate a trained agent may want to specify different metrics to compute without + altering the config file. + .. important:: From 4bed143f706e3adf355b6354c4deab839d41c519 Mon Sep 17 00:00:00 2001 From: vict0rsch Date: Tue, 5 Mar 2024 10:54:06 -0500 Subject: [PATCH 089/106] outline --- gflownet/evaluator/__init__.py | 43 ++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index cefc88f6a..377651080 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -11,7 +11,7 @@ This dual use explains some seaminlgy redundant methods / or arguments to methods. - For instance in :`gflownet.evaluator.abstract.AbstractEvaluator.eval` the + For instance in :meth:`gflownet.evaluator.abstract.AbstractEvaluator.eval` the ``metrics`` argument will never change during the training of a GflowNet (it will always be ``None``, *i.e.* inherited from the config file) but a user looking to evaluate a trained agent may want to specify different metrics to compute without @@ -111,10 +111,39 @@ class methods to instantiate an evaluator. Implementing your own evaluator ------------------------------- +In general, you will inherit from :class:`~gflownet.evaluator.base.BaseEvaluator` and +override the following methods: + +* ``define_new_metrics``: define new metrics and associated requirements. +* ``eval``: compute the metrics and return them as a ``dict``: + `` {"metrics": {metric_name: metric_value}, "data": {str: Any}}``. +* ``plot``: return a ``dict`` of figures as ``{figure_title: figure}``. + +By default, the training loop will call the ``eval_and_log`` method which itself calls +the ``eval`` method to log the metrics, and the ``plot`` method to log the figures: + +..code-block:: python + + def eval_and_log(self, metrics=None, **plot_kwargs): + results = self.eval(metrics=metrics) + for m, v in results["metrics"].items(): + setattr(self.gfn, m, v) + + mertics_to_log = { + METRICS[k]["display_name"]: v for k, v in results["metrics"].items() + } + + figs = self.plot(**results["data"]) + + self.logger.log_metrics(mertics_to_log, it, self.gfn.use_context) + self.logger.log_plots(figs, it, use_context=self.gfn.use_context) + +Example implementation: + .. code-block:: python # gflownet/evaluator/my_evaluator.py - from gflownet.evaluator.base import BaseEvaluator, METRICS, ALL_REQS + from gflownet.evaluator.base import BaseEvaluator class MyEvaluator(BaseEvaluator): def define_new_metrics(self): @@ -122,12 +151,12 @@ def define_new_metrics(self): This method is called when the class is instantiated and is used to update the global METRICS and ALL_REQS variables. ''' - return { - "your_metric": { - "display_name": "My custom metric", - "requirements": ["density", "new_req"], - }, + my_metrics = super().define_new_metrics() + my_metrics["new_metric"] = { + "display_name": "My custom metric", + "requirements": ["density", "new_req"], } + return my_metrics def my_custom_metric(self, some, arguments): From 66e457d16c99e32167a159e605234f8fd04f9267 Mon Sep 17 00:00:00 2001 From: vict0rsch Date: Tue, 5 Mar 2024 10:56:15 -0500 Subject: [PATCH 090/106] typo --- gflownet/evaluator/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index 377651080..c77b75df3 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -122,7 +122,7 @@ class methods to instantiate an evaluator. By default, the training loop will call the ``eval_and_log`` method which itself calls the ``eval`` method to log the metrics, and the ``plot`` method to log the figures: -..code-block:: python +.. code-block:: python def eval_and_log(self, metrics=None, **plot_kwargs): results = self.eval(metrics=metrics) From aed7110aaa2853ab21791472cfb4c6b77543d3b8 Mon Sep 17 00:00:00 2001 From: vict0rsch Date: Tue, 5 Mar 2024 11:09:02 -0500 Subject: [PATCH 091/106] add note --- gflownet/evaluator/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index c77b75df3..e7191023b 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -116,7 +116,7 @@ class methods to instantiate an evaluator. * ``define_new_metrics``: define new metrics and associated requirements. * ``eval``: compute the metrics and return them as a ``dict``: - `` {"metrics": {metric_name: metric_value}, "data": {str: Any}}``. + ``{"metrics": {metric_name: metric_value}, "data": {str: Any}}``. * ``plot``: return a ``dict`` of figures as ``{figure_title: figure}``. By default, the training loop will call the ``eval_and_log`` method which itself calls @@ -314,6 +314,12 @@ def eval(self, metrics=None, **plot_kwargs): period: 1000 +.. note:: + + In general, you should not override the ``make_requirements`` or ``make_metrics`` + methods. They should be used as-is in your ``eval`` method (or any other) to decide + which metrics and plots to compute. + In the previous example, the ``define_new_metrics`` method is used to define new metrics and associated requirements. It will be called when the ``MyEvaluator`` class is instantiated, in the init of From 918fc7c2f62cf679abf298ebbd59c45c006b9155 Mon Sep 17 00:00:00 2001 From: vict0rsch Date: Tue, 5 Mar 2024 14:12:49 -0500 Subject: [PATCH 092/106] docs `plot` and `eval_top_k` --- gflownet/evaluator/abstract.py | 41 ++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 71acfcfcd..76cf3a280 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -547,6 +547,25 @@ def should_checkpoint(self, step): @abstractmethod def plot(self, **kwargs): + """ + The main method to plot results. + + Will be called by the :meth:`eval_and_log` method to plot the results + of the evaluation. + Will be passed the results of the :meth:`eval` method: + + .. code-block:: python + + # in eval_and_log + results = self.eval(metrics=metrics) + figs = self.plot(**results["data"]) + + Returns + ------- + dict + Dictionary of figures to log, with the figure names as keys and the figures + as values. + """ pass @abstractmethod @@ -585,6 +604,28 @@ def eval(self, metrics=None, **plot_kwargs): @abstractmethod def eval_top_k(self, it): + """ + Evaluate the ``GFlowNetAgent``'s top k samples performance. + + Classes extending this abstract class should implement this method. + + Parameters + ---------- + it : int + Current iteration step. + + Returns + ------- + dict + Dictionary with the following keys schema: + .. code-block:: python + + { + "metrics": {str: float}, + "figs": {str: plt.Figure}, + "summary": {str: float}, + } + """ pass def eval_and_log(self, it, metrics=None): From fd3fabda94d189f1940c3af1d0ddd0a058c662e7 Mon Sep 17 00:00:00 2001 From: vict0rsch Date: Thu, 7 Mar 2024 11:26:44 -0500 Subject: [PATCH 093/106] test code-include --- docs/conf.py | 1 + docs/requirements-docs.txt | 1 + gflownet/evaluator/abstract.py | 8 ++++++++ 3 files changed, 10 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index c3199034b..0a275b706 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -47,6 +47,7 @@ "sphinx_design", "sphinx_copybutton", "sphinxext.opengraph", + "code_include.extension", ] # Add any paths that contain templates here, relative to this directory. diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index e07cee7a5..4ebb951a1 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -9,3 +9,4 @@ sphinx-copybutton==0.5.1 sphinx-hoverxref==1.3.0 sphinxext-opengraph==0.8.2 sphinx-autoapi==3.0.0 +sphinx-code-include==1.1.1 \ No newline at end of file diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 76cf3a280..1b2349c25 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -16,6 +16,14 @@ which will be called by the :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` method: +.. code-include :: :meth:`gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` + +.. code-include :: :func:`gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` + +.. code-include :: :class:`gflownet.gflownet.abstract.AbstractEvaluator` + +.. code-include :: :func:`gflownet.utils.common.gflownet_from_config` + .. code-block:: python def eval_and_log(self, it, metrics=None): From ece173ebda511dd94f2257184bad04ab9b7de450 Mon Sep 17 00:00:00 2001 From: carriepl <832811+carriepl@users.noreply.github.com> Date: Thu, 30 May 2024 13:14:48 -0400 Subject: [PATCH 094/106] Update gflownet/evaluator/__init__.py Co-authored-by: Alex --- gflownet/evaluator/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index e7191023b..0f15b6fd7 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -1,5 +1,5 @@ """ -An ``Evaluator`` is a class that is used to compute metrics and plots. +An ``Evaluator`` is a class that is used to compute metrics and generate plots. It serves two complementary purposes: 1. It is used to evaluate the performance of the agent during training and to log the From 93baba4d6722e759350bf8afc617a1741b203b22 Mon Sep 17 00:00:00 2001 From: carriepl <832811+carriepl@users.noreply.github.com> Date: Thu, 30 May 2024 13:15:28 -0400 Subject: [PATCH 095/106] Apply suggestions from code review - Improve docstrings Co-authored-by: Alex --- gflownet/evaluator/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index 0f15b6fd7..d288f0399 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -2,9 +2,9 @@ An ``Evaluator`` is a class that is used to compute metrics and generate plots. It serves two complementary purposes: -1. It is used to evaluate the performance of the agent during training and to log the +1. Evaluate the performance of the agent during training and to log the results. -2. It is intended to be used to evaluate the performance of a trained agent, from a +2. Evaluate the performance of a trained agent, from a directory containing the agent's checkpoints for instance. .. note:: From 43ef77e44c51d0ebb577370b997aac058ec96434 Mon Sep 17 00:00:00 2001 From: carriepl <832811+carriepl@users.noreply.github.com> Date: Thu, 30 May 2024 13:15:40 -0400 Subject: [PATCH 096/106] Update gflownet/evaluator/__init__.py Co-authored-by: Alex --- gflownet/evaluator/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/__init__.py b/gflownet/evaluator/__init__.py index d288f0399..e0e070fcc 100644 --- a/gflownet/evaluator/__init__.py +++ b/gflownet/evaluator/__init__.py @@ -48,7 +48,7 @@ class methods to instantiate an evaluator. Basic concepts -------------- -The evaluator is used to compute metrics and plots. It is used to evaluate the +The evaluator is used to compute metrics and generate plots. It is used to evaluate the performance of the agent during training and to log the results. It is also intended to be used to evaluate the performance of a trained agent. From 7a1f97af58124bfeefba787669826c1f3f926e98 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Fri, 31 May 2024 10:26:29 -0400 Subject: [PATCH 097/106] Complete GFlowNetAgent docstring --- gflownet/gflownet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 172b9ccd0..4311d2659 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -55,7 +55,7 @@ def __init__( **kwargs, ): """ - Main class of this repository. Handles + Main class of this repository. Handles the training logic for a GFlowNet model. Parameters ---------- From 0b0b0fc12f28c4d6bcb2ea0d499342cd62ccafd4 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Fri, 31 May 2024 10:30:24 -0400 Subject: [PATCH 098/106] Remove unused variable --- scripts/eval_gflownet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index 1fda71ebd..d3f9dd7d2 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -175,7 +175,7 @@ def main(args): print("output_dir: ", str(output_dir)) output_dir.mkdir(parents=True, exist_ok=True) - for k, (figname, fig) in enumerate(eval_results["figs"].items()): + for figname, fig in eval_results["figs"].items(): output_fig = output_dir / (path_compatible(figname) + ".pdf") if fig is not None: fig.savefig(output_fig, bbox_inches="tight") From 800314ce4065070ad8a54750514a4c8a0e32df3f Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Tue, 4 Jun 2024 09:22:01 -0400 Subject: [PATCH 099/106] Fix pytest filename conflicts --- tests/gflownet/evaluator/__init__.py | 0 tests/gflownet/proxy/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/gflownet/evaluator/__init__.py create mode 100644 tests/gflownet/proxy/__init__.py diff --git a/tests/gflownet/evaluator/__init__.py b/tests/gflownet/evaluator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/gflownet/proxy/__init__.py b/tests/gflownet/proxy/__init__.py new file mode 100644 index 000000000..e69de29bb From a9bea9a069996eeb5d19d67097332415b556a707 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Tue, 4 Jun 2024 14:09:28 -0400 Subject: [PATCH 100/106] Re-integrate changes from main lost in merge --- config/evaluator/base.yaml | 2 ++ gflownet/evaluator/base.py | 45 ++++++++++++++++++++++++++------------ gflownet/gflownet.py | 33 ++++++++++++++++++++++++++-- gflownet/utils/buffer.py | 1 - gflownet/utils/common.py | 8 ++++--- 5 files changed, 69 insertions(+), 20 deletions(-) diff --git a/config/evaluator/base.yaml b/config/evaluator/base.yaml index 03af9450f..c429fe808 100644 --- a/config/evaluator/base.yaml +++ b/config/evaluator/base.yaml @@ -16,6 +16,8 @@ logprobs_batch_size: 100 logprobs_bootstrap_size: 10000 # Maximum number of test data points to compute log likelihood probs. max_data_logprobs: 1e5 +# Number of points to obtain a grid to estimate the reward density +n_grid: 40000 train_log_period: 1 checkpoints_period: 1000 # List of metrics as per gflownet/eval/evaluator.py:METRICS_NAMES diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 7c36786c2..b2d02eca7 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -129,7 +129,10 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): if not gfn_states: # sample states from the current gfn batch = Batch( - env=self.gfn.env, device=self.gfn.device, float_type=self.gfn.float + env=self.gfn.env, + proxy=self.gfn.proxy, + device=self.gfn.device, + float_type=self.gfn.float, ) self.gfn.random_action_prob = 0 t = time.time() @@ -154,7 +157,10 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): # sample random states from uniform actions if not random_states: batch = Batch( - env=self.gfn.env, device=self.gfn.device, float_type=self.gfn.float + env=self.gfn.env, + proxy=self.gfn.proxy, + device=self.gfn.device, + float_type=self.gfn.float, ) self.gfn.random_action_prob = 1.0 print("[eval_top_k] Sampling at random...", end="\r") @@ -264,16 +270,14 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): lp_metrics["mean_probs_std"] = probs_std.mean().item() if "reward_batch" in reqs: - rewards_x_tt = self.gfn.env.reward_batch(x_tt) + rewards_x_tt = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt)) if "corr_prob_traj_rewards" in metrics: - rewards_x_tt = self.gfn.env.reward_batch(x_tt) lp_metrics["corr_prob_traj_rewards"] = np.corrcoef( np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt )[0, 1] if "var_logrewards_logp" in metrics: - rewards_x_tt = self.gfn.env.reward_batch(x_tt) lp_metrics["var_logrewards_logp"] = torch.var( torch.log( tfloat( @@ -342,9 +346,11 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): x_sampled = batch.get_terminating_states() if "density_true" in dict_tt: - density_true = dict_tt["density_true"] + density_true = torch2np(dict_tt["density_true"]) else: - rewards = self.gfn.env.reward_batch(x_tt) + rewards = torch2np( + self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt)) + ) z_true = rewards.sum() density_true = rewards / z_true with open(self.gfn.buffer.test_pkl, "wb") as f: @@ -361,9 +367,8 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): elif self.gfn.continuous and hasattr(self.gfn.env, "fit_kde"): batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False) assert batch.is_valid() - x_sampled = batch.get_terminating_states() + x_sampled = batch.get_terminating_states(proxy=True) # TODO make it work with conditional env - x_sampled = torch2np(self.gfn.env.states2proxy(x_sampled)) x_tt = torch2np(self.gfn.env.states2proxy(x_tt)) kde_pred = self.gfn.env.fit_kde( x_sampled, @@ -375,8 +380,9 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None): kde_true = dict_tt["kde_true"] else: # Sample from reward via rejection sampling - x_from_reward = self.gfn.env.sample_from_reward(n_samples=self.config.n) - x_from_reward = torch2np(self.gfn.env.states2proxy(x_from_reward)) + x_from_reward = self.gfn.env.states2proxy( + self.gfn.sample_from_reward(n_samples=self.config.n) + ) # Fit KDE with samples from reward kde_true = self.gfn.env.fit_kde( x_from_reward, @@ -547,15 +553,26 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): fig_kde_pred = fig_kde_true = fig_reward_samples = None if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None: + (sample_space_batch, rewards_sample_space) = ( + self.gfn.get_sample_space_and_reward() + ) fig_reward_samples = self.gfn.env.plot_reward_samples( - x_sampled, **plot_kwargs + x_sampled, + sample_space_batch, + rewards_sample_space, + **plot_kwargs, ) if hasattr(self.gfn.env, "plot_kde"): + sample_space_batch, _ = self.gfn.get_sample_space_and_reward() if kde_pred is not None: - fig_kde_pred = self.gfn.env.plot_kde(kde_pred, **plot_kwargs) + fig_kde_pred = self.gfn.env.plot_kde( + sample_space_batch, kde_pred, **plot_kwargs + ) if kde_true is not None: - fig_kde_true = self.gfn.env.plot_kde(kde_true, **plot_kwargs) + fig_kde_true = self.gfn.env.plot_kde( + sample_space_batch, kde_true, **plot_kwargs + ) return { "True reward and GFlowNet samples": fig_reward_samples, diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 87d7da7d5..9bc9d0c47 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -169,7 +169,6 @@ def __init__( **buffer, env=self.env, proxy=self.proxy, - make_train_test=not sample_only, logger=logger, ) # Train set statistics and reward normalization constant @@ -1212,6 +1211,36 @@ def train(self): if self.use_context is False: self.logger.end() + def get_sample_space_and_reward(self): + """ + Returns samples representative of the env state space with their rewards + + Returns + ------- + sample_space_batch : tensor + Repressentative terminating states for the environment + rewards_sample_space : tensor + Rewards associated with the tates in sample_space_batch + """ + if not hasattr(self, "sample_space_batch"): + if hasattr(self.env, "get_all_terminating_states"): + self.sample_space_batch = self.env.get_all_terminating_states() + elif hasattr(self.env, "get_grid_terminating_states"): + self.sample_space_batch = self.env.get_grid_terminating_states( + self.evaluator.config.n_grid + ) + else: + raise NotImplementedError( + "In order to obtain representative terminating states, the " + "environment must implement either get_all_terminating_states() " + "or get_grid_terminating_states()" + ) + self.sample_space_batch = self.env.states2proxy(self.sample_space_batch) + if not hasattr(self, "rewards_sample_space"): + self.rewards_sample_space = self.proxy.rewards(self.sample_space_batch) + + return self.sample_space_batch, self.rewards_sample_space + # TODO: implement other proposal distributions # TODO: rethink whether it is needed to convert to reward def sample_from_reward( @@ -1243,7 +1272,7 @@ def sample_from_reward( format. """ samples_final = [] - max_reward = self.get_max_reward() + max_reward = self.proxy.get_max_reward() while len(samples_final) < n_samples: if proposal_distribution == "uniform": # TODO: sample only the remaining number of samples diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index f214d601b..ac1447342 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -20,7 +20,6 @@ def __init__( self, env, proxy, - make_train_test=False, replay_capacity=0, output_csv=None, data_path=None, diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 063133383..1e14bd65c 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -251,12 +251,13 @@ def gflownet_from_config(config): ) # The proxy is passed to env and used for computing rewards - env = instantiate( + env_maker = instantiate( config.env, - proxy=proxy, device=config.device, float_precision=config.float_precision, + _partial_=True, ) + env = env_maker() # The evaluator is used to compute metrics and plots evaluator = instantiate(config.evaluator) @@ -296,7 +297,8 @@ def gflownet_from_config(config): config.gflownet, device=config.device, float_precision=config.float_precision, - env=env, + env_maker=env_maker, + proxy=proxy, forward_policy=forward_policy, backward_policy=backward_policy, state_flow=state_flow, From 81e474539be23f7de53dcdcab1de388c62573796 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Tue, 4 Jun 2024 16:00:12 -0400 Subject: [PATCH 101/106] Update comments --- main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/main.py b/main.py index 8f3d3628b..c3a4a6780 100644 --- a/main.py +++ b/main.py @@ -33,14 +33,13 @@ def main(config): # Logger logger = hydra.utils.instantiate(config.logger, config, _recursive_=False) - # The proxy is required for scoring + # The proxy is required by the GFlowNetAgent for computing rewards proxy = hydra.utils.instantiate( config.proxy, device=config.device, float_precision=config.float_precision, ) - # The proxy is passed to env and used for computing rewards # Using Hydra's partial instantiation, see: # https://hydra.cc/docs/advanced/instantiate_objects/overview/#partial-instantiation env_maker = hydra.utils.instantiate( From 689ed6ab68c6e7455f007f1bd3743b155badfc88 Mon Sep 17 00:00:00 2001 From: Pierre Luc Carrier Date: Wed, 5 Jun 2024 10:29:39 -0400 Subject: [PATCH 102/106] Improve comments in Logger --- gflownet/utils/buffer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index ac1447342..5ebeb51c4 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -51,7 +51,8 @@ def __init__( self.test_pkl = None self.save_replay() - # Define train and test data sets + + # Define train data set if train is not None and "type" in train: self.train_type = train.type else: @@ -83,6 +84,7 @@ def __init__( ) self.train_pkl = None + # Define test data set if test is not None and "type" in test: self.test_type = test.type else: @@ -109,6 +111,7 @@ def __init__( """ ) self.test_pkl = None + # Compute buffer statistics if self.train is not None: ( From d689270be659e258bc8686e078888bc939e85ae6 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 5 Jun 2024 16:42:57 -0400 Subject: [PATCH 103/106] Adjust sanity check runs (CTorus) to Evaluator config --- mila/dev/sanity_check_runs.md | 6 +++--- mila/dev/sanity_check_runs.yaml | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mila/dev/sanity_check_runs.md b/mila/dev/sanity_check_runs.md index 8b36b4799..b138c8c94 100644 --- a/mila/dev/sanity_check_runs.md +++ b/mila/dev/sanity_check_runs.md @@ -116,17 +116,17 @@ python mila/launch.py --conda_env= user=$USER env=tetris proxy=t `salloc`: ```bash -python main.py user=$USER +experiments=icml23/ctorus logger.test.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python main.py user=$USER +experiments=icml23/ctorus evaluator.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True ``` `sbatch` with `virtualenv`: ```bash -python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER +experiments=icml23/ctorus logger.test.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER +experiments=icml23/ctorus evaluator.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True ``` `sbatch` with `conda`: ```bash -python mila/launch.py --conda_env= user=$USER +experiments=icml23/ctorus logger.test.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --conda_env= user=$USER +experiments=icml23/ctorus evaluator.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True ``` diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index cad19cf90..d82125f9a 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -106,6 +106,5 @@ jobs: job_name: sanity-ctorus script: +experiments: icml23/ctorus - logger: - test: - period: 500 + evalutor: + period: 500 From 4f4fcadafda5f6e21e1e1d9be10f18e1a2fae5df Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 5 Jun 2024 17:26:53 -0400 Subject: [PATCH 104/106] Fix typo --- gflownet/evaluator/abstract.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gflownet/evaluator/abstract.py b/gflownet/evaluator/abstract.py index 1b2349c25..d3da5aff1 100644 --- a/gflownet/evaluator/abstract.py +++ b/gflownet/evaluator/abstract.py @@ -31,13 +31,13 @@ def eval_and_log(self, it, metrics=None): for m, v in results["metrics"].items(): setattr(self.gfn, m, v) - mertics_to_log = { + metrics_to_log = { METRICS[k]["display_name"]: v for k, v in results["metrics"].items() } figs = self.plot(**results["data"]) - self.logger.log_metrics(mertics_to_log, it, self.gfn.use_context) + self.logger.log_metrics(metrics_to_log, it, self.gfn.use_context) self.logger.log_plots(figs, it, use_context=self.gfn.use_context) See :mod:`gflownet.evaluator` for a full-fledged example and @@ -655,13 +655,13 @@ def eval_and_log(self, it, metrics=None): for m, v in results["metrics"].items(): setattr(self.gfn, m, v) - mertics_to_log = { + metrics_to_log = { METRICS[k]["display_name"]: v for k, v in results["metrics"].items() } figs = self.plot(**results["data"]) - self.logger.log_metrics(mertics_to_log, it, self.gfn.use_context) + self.logger.log_metrics(metrics_to_log, it, self.gfn.use_context) self.logger.log_plots(figs, it, use_context=self.gfn.use_context) def eval_and_log_top_k(self, it): From 25a45e243827c78689bb5ccabc79838e0a3de4d3 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 5 Jun 2024 19:37:56 -0400 Subject: [PATCH 105/106] Evaluator: add samples_topk to plot(); add TODOs --- gflownet/envs/tetris.py | 3 ++- gflownet/evaluator/base.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index 9cd2f722b..bf927de2b 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -534,6 +534,7 @@ def plot_samples_topk( k_top: int = 10, n_rows: int = 2, dpi: int = 150, + **kwargs, ): """ Plot tetris boards of top K samples. @@ -543,7 +544,7 @@ def plot_samples_topk( samples : list List of terminating states sampled from the policy. rewards : list - List of terminating states. + Rewards of the samples. k_top : int The number of samples that will be included in the plot. The k_top samples with the highest reward are selected. diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index b2d02eca7..9085d9f84 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -98,6 +98,9 @@ def define_new_metrics(self): }, } + # TODO: this method will most likely crash if used (top_k_period != -1) because + # self.gfn.env.top_k_metrics_and_plots still makes use of env.proxy. + # Re-implementing this wil require a non-trivial amount of work. @torch.no_grad() def eval_top_k(self, it, gfn_states=None, random_states=None): """ @@ -124,6 +127,7 @@ def eval_top_k(self, it, gfn_states=None, random_states=None): do_random = it // self.logger.test.top_k_period == 1 duration = None summary = {} + # TODO: Why deepcopy? prob = copy.deepcopy(self.random_action_prob) print() if not gfn_states: @@ -517,9 +521,12 @@ def plot( Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which will be logged. - By default, this method will call the `plot_reward_samples` method of the - GFlowNetAgent's environment, and the `plot_kde` method of the GFlowNetAgent's - environment if it exists for both the `kde_pred` and `kde_true` arguments. + By default, this method will call the following methods of the GFlowNetAgent's + environment if they exist: + + - `plot_reward_samples` + - `plot_kde` (for both the `kde_pred` and `kde_true` arguments) + - `plot_samples_topk` Extend this method to add more plots: @@ -574,8 +581,24 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): sample_space_batch, kde_true, **plot_kwargs ) + # TODO: consider moving this to eval_top_k once fixed + if hasattr(self.gfn.env, "plot_samples_topk"): + if x_sampled is None: + batch, _ = self.gfn.sample_batch( + n_forward=self.config.n_top_k, train=False + ) + x_sampled = batch.get_terminating_states() + rewards = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_sampled)) + fig_samples_topk = self.gfn.env.plot_samples_topk( + x_sampled, + rewards, + self.config.top_k, + **plot_kwargs, + ) + return { "True reward and GFlowNet samples": fig_reward_samples, "GFlowNet KDE Policy": fig_kde_pred, "Reward KDE": fig_kde_true, + "Samples TopK": fig_samples_topk, } From 6efaa1963f5af9ed5ea6643723f214d11e3aaa26 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 5 Jun 2024 19:38:33 -0400 Subject: [PATCH 106/106] Update evaluator config of Tetris sanity runs --- mila/dev/sanity_check_runs.md | 12 ++++++------ mila/dev/sanity_check_runs.yaml | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mila/dev/sanity_check_runs.md b/mila/dev/sanity_check_runs.md index b138c8c94..c3409e562 100644 --- a/mila/dev/sanity_check_runs.md +++ b/mila/dev/sanity_check_runs.md @@ -73,19 +73,19 @@ python mila/launch.py --conda_env= user=$USER env=grid env.lengt `salloc`: ```bash -python main.py user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=trajectorybalance device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python main.py user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=trajectorybalance device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True evaluator.top_k=10 evaluator.n_top_k=100 ``` `sbatch` with `virtualenv`: ```bash -python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=trajectorybalance device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=trajectorybalance device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True evaluator.top_k=10 evaluator.n_top_k=100 ``` `sbatch` with `conda`: ```bash -python mila/launch.py --conda_env= user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=trajectorybalance device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --conda_env= user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=trajectorybalance device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True evaluator.top_k=10 evaluator.n_top_k=100 ``` ### Flow Matching loss @@ -93,19 +93,19 @@ python mila/launch.py --conda_env= user=$USER env=tetris proxy=t `salloc`: ```bash -python main.py user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=flowmatch device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python main.py user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=flowmatch device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True evaluator.top_k=10 evaluator.n_top_k=100 ``` `sbatch` with `virtualenv`: ```bash -python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=flowmatch device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=flowmatch device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True evaluator.top_k=10 evaluator.n_top_k=100 ``` `sbatch` with `conda`: ```bash -python mila/launch.py --conda_env= user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=flowmatch device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --conda_env= user=$USER env=tetris proxy=tetris env.width=5 env.height=10 gflownet=flowmatch device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True evaluator.top_k=10 evaluator.n_top_k=100 ``` ## Continuous Torus as in Lahlou et al (ICML 2023) diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index d82125f9a..636abeda2 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -52,6 +52,9 @@ jobs: height: 10 gflownet: flowmatch proxy: tetris + evaluator: + top_k: 10 + n_top_k: 100 - slurm: job_name: sanity-tetris-tb script: @@ -61,6 +64,9 @@ jobs: height: 10 gflownet: trajectorybalance proxy: tetris + evaluator: + top_k: 10 + n_top_k: 100 # Mini-Tetris - slurm: job_name: sanity-mintetris-fm @@ -75,6 +81,9 @@ jobs: __value__: tetris reward_function: exponential gflownet: flowmatch + evaluator: + top_k: 10 + n_top_k: 100 - slurm: job_name: sanity-mintetris-tb script: @@ -88,6 +97,9 @@ jobs: __value__: tetris reward_function: exponential gflownet: trajectorybalance + evaluator: + top_k: 10 + n_top_k: 100 - slurm: job_name: sanity-mintetris-fl script: @@ -101,6 +113,9 @@ jobs: __value__: tetris reward_function: exponential gflownet: forwardlooking + evaluator: + top_k: 10 + n_top_k: 100 # Ctorus - slurm: job_name: sanity-ctorus