Skip to content

Commit

Permalink
Remove references to crystals and oracle in eval_gflownet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Feb 21, 2024
1 parent 967c5b8 commit cc90b5c
Showing 1 changed file with 3 additions and 39 deletions.
42 changes: 3 additions & 39 deletions scripts/eval_gflownet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Computes evaluation metrics and plots from a pre-trained GFlowNet model.
"""

import pickle
import shutil
import sys
Expand All @@ -13,8 +14,6 @@

sys.path.append(str(Path(__file__).resolve().parent.parent))

from crystalrandom import generate_random_crystals

from gflownet.gflownet import GFlowNetAgent
from gflownet.utils.common import load_gflow_net_from_run_path
from gflownet.utils.policy import parse_policy_config
Expand Down Expand Up @@ -69,11 +68,6 @@ def add_args(parser):
action="store_true",
help="Sample from an untrained GFlowNet",
)
parser.add_argument(
"--random_crystals",
action="store_true",
help="Sample crystals uniformly, without constraints",
)
parser.add_argument("--device", default="cpu", type=str)
return parser

Expand Down Expand Up @@ -178,6 +172,7 @@ def main(args):
# ----- Sample GFlowNet -----
# ------------------------------------------

# Handle output directory
output_dir = base_dir / "eval" / "samples"
output_dir.mkdir(parents=True, exist_ok=True)
tmp_dir = output_dir / "tmp"
Expand All @@ -193,7 +188,7 @@ def main(args):
):
batch, times = gflownet.sample_batch(n_forward=bs, train=False)
x_sampled = batch.get_terminating_states(proxy=True)
energies = env.oracle(x_sampled)
energies = env.proxy(x_sampled)
x_sampled = batch.get_terminating_states()
df = pd.DataFrame(
{
Expand All @@ -218,37 +213,6 @@ def main(args):
if "y" in input("Delete temporary files? (y/n)"):
shutil.rmtree(tmp_dir)

# ------------------------------------
# ----- Sample random crystals -----
# ------------------------------------

# Sample random crystals uniformly without constraints
if args.random_crystals and args.n_samples > 0 and args.n_samples <= 1e5:
print(f"Sampling {args.n_samples} random crystals without constraints...")
x_sampled = generate_random_crystals(
n_samples=args.n_samples,
elements=config.env.composition_kwargs.elements,
min_elements=2,
max_elements=5,
max_atoms=config.env.composition_kwargs.max_atoms,
max_atom_i=config.env.composition_kwargs.max_atom_i,
space_groups=config.env.space_group_kwargs.space_groups_subset,
min_length=0.0,
max_length=1.0,
min_angle=0.0,
max_angle=1.0,
)
energies = env.oracle(env.states2proxy(x_sampled))
df = pd.DataFrame(
{
"readable": [env.state2readable(x) for x in x_sampled],
"energies": energies.tolist(),
}
)
df.to_csv(output_dir / "randomcrystals_samples.csv")
dct = {"x": x_sampled, "energy": energies.tolist()}
pickle.dump(dct, open(output_dir / "randomcrystals_samples.pkl", "wb"))


if __name__ == "__main__":
parser = ArgumentParser()
Expand Down

0 comments on commit cc90b5c

Please sign in to comment.