Skip to content

Commit

Permalink
Small fixes in the evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Nov 27, 2023
1 parent 31a4239 commit 15e77fa
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
3 changes: 0 additions & 3 deletions gflownet/utils/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,6 @@ def make_data_set(self, config):
f", but only {n_samples_new} are valid according to the "
"environment settings. Invalid samples have been discarded."
)
n_max = 100
samples = samples[:n_max]
print(f"Only the first {n_max} samples will be kept in the data.")
print("Remember to write a function to normalise the data in code")
print("Max number of elements in data set has to match config")
print("Actually, write a function that contrasts the stats")
Expand Down
15 changes: 9 additions & 6 deletions gflownet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,16 @@ def resolve_path(path: str) -> Path:
return Path(expandvars(str(path))).expanduser().resolve()


def find_latest_checkpoint(ckpt_dir, pattern):
final = list(ckpt_dir.glob(f"{pattern}*final*"))
def find_latest_checkpoint(ckpt_dir, ckpt_name):
ckpt_name = Path(ckpt_name).stem
final = list(ckpt_dir.glob(f"{ckpt_name}*final*"))
if len(final) > 0:
return final[0]
ckpts = list(ckpt_dir.glob(f"{pattern}*"))
ckpts = list(ckpt_dir.glob(f"{ckpt_name}*"))
if not ckpts:
raise ValueError(f"No checkpoints found in {ckpt_dir} with pattern {pattern}")
raise ValueError(
f"No final checkpoints found in {ckpt_dir} with pattern {ckpt_name}*final*"
)
return sorted(ckpts, key=lambda f: float(f.stem.split("iter")[1]))[-1]


Expand Down Expand Up @@ -175,12 +178,12 @@ def load_gflow_net_from_run_path(
# -------------------------------

ckpt = [f for f in run_path.rglob(config.logger.logdir.ckpts) if f.is_dir()][0]
forward_final = find_latest_checkpoint(ckpt, "pf")
forward_final = find_latest_checkpoint(ckpt, config.policy.forward.checkpoint)
gflownet.forward_policy.model.load_state_dict(
torch.load(forward_final, map_location=set_device(device))
)
try:
backward_final = find_latest_checkpoint(ckpt, "pb")
backward_final = find_latest_checkpoint(ckpt, config.policy.backward.checkpoint)
gflownet.backward_policy.model.load_state_dict(
torch.load(backward_final, map_location=set_device(device))
)
Expand Down

0 comments on commit 15e77fa

Please sign in to comment.