Skip to content

Commit

Permalink
conditional script, black, isort
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandraVolokhova committed Jul 15, 2024
1 parent 00e3f52 commit b626451
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 43 deletions.
8 changes: 4 additions & 4 deletions gflownet/evaluator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class methods to instantiate an evaluator.


class BaseEvaluator(AbstractEvaluator):

def __init__(self, gfn_agent=None, **config):
"""
Base evaluator class for GFlowNetAgent.
Expand Down Expand Up @@ -560,9 +559,10 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
fig_kde_pred = fig_kde_true = fig_reward_samples = fig_samples_topk = None

if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None:
(sample_space_batch, rewards_sample_space) = (
self.gfn.get_sample_space_and_reward()
)
(
sample_space_batch,
rewards_sample_space,
) = self.gfn.get_sample_space_and_reward()
fig_reward_samples = self.gfn.env.plot_reward_samples(
x_sampled,
sample_space_batch,
Expand Down
24 changes: 12 additions & 12 deletions gflownet/policy/multihead_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,9 @@ def forward(self, x):
logits[indices, self.leaf_index : self.feature_index] = y_leaf
logits[indices, self.eos_index] = y_eos
elif stage == Stage.LEAF:
logits[indices, self.feature_index : self.threshold_index] = (
self.feature_head(batch)
)
logits[
indices, self.feature_index : self.threshold_index
] = self.feature_head(batch)
else:
ks = [Tree.find_active(state) for state in states]
feature_index = torch.Tensor(
Expand All @@ -374,9 +374,9 @@ def forward(self, x):
if self.continuous:
logits[indices, (self.eos_index + 1) :] = head_output
else:
logits[indices, self.threshold_index : self.operator_index] = (
head_output
)
logits[
indices, self.threshold_index : self.operator_index
] = head_output
elif stage == Stage.THRESHOLD:
threshold = torch.Tensor(
[
Expand Down Expand Up @@ -464,14 +464,14 @@ def forward(self, x):
)

if stage == Stage.COMPLETE:
logits[indices, self.operator_index : self.eos_index] = (
self.complete_stage_head(batch)
)
logits[
indices, self.operator_index : self.eos_index
] = self.complete_stage_head(batch)
logits[indices, self.eos_index] = 1.0
elif stage == Stage.LEAF:
logits[indices, self.leaf_index : self.feature_index] = (
self.leaf_stage_head(batch)
)
logits[
indices, self.leaf_index : self.feature_index
] = self.leaf_stage_head(batch)
elif stage == Stage.FEATURE:
logits[indices, self.feature_index : self.threshold_index] = 1.0
elif stage == Stage.THRESHOLD:
Expand Down
6 changes: 3 additions & 3 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,9 +796,9 @@ def get_masks_forward(
masks_invalid_actions_forward_parents[parents_indices == -1] = self.source[
"mask_forward"
]
masks_invalid_actions_forward_parents[parents_indices != -1] = (
masks_invalid_actions_forward[parents_indices[parents_indices != -1]]
)
masks_invalid_actions_forward_parents[
parents_indices != -1
] = masks_invalid_actions_forward[parents_indices[parents_indices != -1]]
return masks_invalid_actions_forward_parents
return masks_invalid_actions_forward

Expand Down
4 changes: 2 additions & 2 deletions scripts/crystal/eval_crystalgflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

sys.path.append(str(Path(__file__).resolve().parent.parent))

from crystalrandom import generate_random_crystals

from gflownet.gflownet import GFlowNetAgent
from gflownet.utils.common import load_gflow_net_from_run_path
from gflownet.utils.policy import parse_policy_config

from crystalrandom import generate_random_crystals


def add_args(parser):
"""
Expand Down
6 changes: 3 additions & 3 deletions scripts/crystal/eval_gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

sys.path.append(str(Path(__file__).resolve().parent.parent.parent))

from crystalrandom import generate_random_crystals_uniform
from hydra.utils import instantiate

from gflownet.gflownet import GFlowNetAgent
from gflownet.utils.common import load_gflow_net_from_run_path, read_hydra_config
from gflownet.utils.policy import parse_policy_config
from hydra.utils import instantiate

from crystalrandom import generate_random_crystals_uniform


def add_args(parser):
Expand Down
21 changes: 4 additions & 17 deletions scripts/crystal/plots_conditional_icml24.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,18 @@
"""
Script for plotting violin plots for conditional sampling.
example cli:
python scripts/crystal/plots_conditional_icml24.py --pkl_path=/home/mila/a/alexandra.volokhova/projects/gflownet-dev/external/data/starling_fe/samples/gfn_samples.pkl --cond_dir_root=/home/mila/a/alexandra.volokhova/projects/gflownet-dev/external/data/starling_fe_conditional
PYTHONPATH=/home/mila/a/alexandra.volokhova/projects/gflownet python scripts/crystal/plots_conditional_icml24.py --pkl_path=/home/mila/a/alexandra.volokhova/projects/gflownet/external/starling_fe/samples/gfn_iter50k_samples.pkl --cond_dir_root=/home/mila/a/alexandra.volokhova/projects/gflownet/external/starling_fe_conditional
"""

import argparse
import datetime
import os
import pickle
import sys
import warnings
from collections import OrderedDict
from pathlib import Path

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from mendeleev.fetch import fetch_table
from plots_icml24 import load_gfn_samples, now_to_str
from seaborn_fig2grid import SeabornFig2Grid
from tqdm import tqdm
from plots_icml24 import now_to_str

ROOT = Path(__file__).resolve().parent.parent.parent

Expand Down Expand Up @@ -116,12 +103,12 @@ def load_energies_only(pkl_path, energy_key="energy"):
print(f"Saving plots to {output_path}")

USE_SUPTITLES = not args.no_suptitles
# elements = ['H', 'Li', 'B', 'C', 'N', 'O', 'F', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'K', 'V', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Se']

sdf = load_energies_only(pkl_path=args.pkl_path)
dfs = {"Crystal-GFN (FE)": sdf}
cond_root = Path(args.cond_dir_root)
cond_paths = {
x[-1].upper(): cond_root / x / "eval/samples/gfn_samples.pkl"
x[-1].upper(): cond_root / x / f"gfn_iter50k_samples_restricted_{x[-1]}.pkl"
for x in os.listdir(cond_root)
}
cdfs = {k: load_energies_only(pkl_path=v) for k, v in cond_paths.items()}
Expand Down
1 change: 0 additions & 1 deletion scripts/crystal/sample_uniform_with_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import hydra
import pandas as pd

from crystalrandom import generate_random_crystals_uniform


Expand Down
1 change: 0 additions & 1 deletion scripts/crystal_eval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def __init__(
def compute(
self, structures: Iterable[Structure], energies: list, **kwargs
) -> dict:

all_compositions = [s.composition for s in structures]
all_energies = energies
compositions = []
Expand Down

0 comments on commit b626451

Please sign in to comment.