
Merge pull request #230 from GFNOrg/recalculate
Default behavior: recalculate logprobs
saleml authored Jan 22, 2025
2 parents 7f03681 + 0cacc50 · commit bbcf21f
Showing 6 changed files with 16 additions and 13 deletions.
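
The substance of the change: save_logprobs now defaults to False across the samplers and GFlowNet classes, so log-probabilities are recalculated from the current policy at loss time unless the caller opts in. A minimal sketch of the two call patterns after this commit, assuming the setup from the library's README (env, sampler, gfn already constructed; not shown in this diff):

    # On-policy training: opt in to saving the sampling logprobs so the loss
    # can reuse them instead of re-running the forward policy.
    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=True)

    # Off-policy training (e.g. tempered sampling): keep the new default,
    # save_logprobs=False, so the loss recalculates logprobs under the
    # current, untempered policy.
    trajectories = sampler.sample_trajectories(env=env, n=16, temperature=1.5)

    loss = gfn.loss(env, trajectories)  # recomputes logprobs when none were saved
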
README.md (9 changes: 6 additions & 3 deletions)

@@ -102,7 +102,7 @@ optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})
 
 # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
 for i in (pbar := tqdm(range(1000))):
-    trajectories = sampler.sample_trajectories(env=env, n=16)
+    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=True)  # save_logprobs=True makes on-policy training faster
     optimizer.zero_grad()
     loss = gfn.loss(env, trajectories)
     loss.backward()
@@ -152,7 +152,7 @@ logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocess
 gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF, lamda=0.9)
 
 # 5 - We define the sampler and the optimizer.
-sampler = Sampler(estimator=pf_estimator)  # We use an on-policy sampler, based on the forward policy
+sampler = Sampler(estimator=pf_estimator)
 
 # Different policy parameters can have their own LR.
 # Log F gets dedicated learning rate (typically higher).
@@ -161,7 +161,10 @@ optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})
 
 # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
 for i in (pbar := tqdm(range(1000))):
-    trajectories = sampler.sample_trajectories(env=env, n=16)
+    # We sample trajectories off-policy by tempering the distribution.
+    # We do not save the sampling logprobs, since they are not used for training.
+    # We save the estimator outputs to make training faster.
+    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=False, save_estimator_outputs=True, temperature=1.5)
     optimizer.zero_grad()
     loss = gfn.loss(env, trajectories)
     loss.backward()
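
Why this example caches estimator outputs rather than logprobs: with tempering, the distribution used for sampling differs from the policy the loss needs, but both are functions of the same network outputs, so caching the raw outputs avoids a second forward pass. A hedged, self-contained sketch of the idea in plain PyTorch (illustrative only, not torchgfn's internal code):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(16, 4)  # stand-in for cached estimator outputs: 16 states, 4 actions
    temperature = 1.5

    # Sampling distribution: tempered, hence off-policy.
    sampling_probs = F.softmax(logits / temperature, dim=-1)
    actions = torch.multinomial(sampling_probs, num_samples=1)

    # Training needs logprobs under the untempered policy; because the raw
    # logits were cached, no second pass through the network is required.
    train_logprobs = F.log_softmax(logits, dim=-1).gather(1, actions)
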
src/gfn/gflownet/base.py (4 changes: 2 additions & 2 deletions)

@@ -31,7 +31,7 @@ def sample_trajectories(
         self,
         env: Env,
         n: int,
-        save_logprobs: bool = True,
+        save_logprobs: bool = False,
         save_estimator_outputs: bool = False,
     ) -> Trajectories:
         """Sample a specific number of complete trajectories.
@@ -93,7 +93,7 @@ def sample_trajectories(
         env: Env,
         n: int,
         conditioning: torch.Tensor | None = None,
-        save_logprobs: bool = True,
+        save_logprobs: bool = False,
         save_estimator_outputs: bool = False,
         **policy_kwargs: Any,
     ) -> Trajectories:
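
A migration note, since this flips the default of an abstract-method signature: any downstream caller that relied on the old default now silently recomputes logprobs at loss time, which is correct but slower for on-policy training. A sketch of the opt-in such a caller would add (n=64 is an arbitrary illustrative batch size):

    # Before this commit, save_logprobs=True was implicit:
    trajectories = gfn.sample_trajectories(env, n=64)

    # After this commit, pass it explicitly to keep the faster on-policy path:
    trajectories = gfn.sample_trajectories(env, n=64, save_logprobs=True)
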
src/gfn/gflownet/flow_matching.py (2 changes: 1 addition & 1 deletion)

@@ -45,7 +45,7 @@ def sample_trajectories(
         env: Env,
         n: int,
         conditioning: torch.Tensor | None = None,
-        save_logprobs: bool = True,
+        save_logprobs: bool = False,
         save_estimator_outputs: bool = False,
         **policy_kwargs: Any,
     ) -> Trajectories:
src/gfn/samplers.py (8 changes: 4 additions & 4 deletions)

@@ -34,7 +34,7 @@ def sample_actions(
         states: States,
         conditioning: torch.Tensor | None = None,
         save_estimator_outputs: bool = False,
-        save_logprobs: bool = True,
+        save_logprobs: bool = False,
         **policy_kwargs: Any,
     ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]:
         """Samples actions from the given states.
@@ -104,7 +104,7 @@ def sample_trajectories(
         states: Optional[States] = None,
         conditioning: Optional[torch.Tensor] = None,
         save_estimator_outputs: bool = False,
-        save_logprobs: bool = True,
+        save_logprobs: bool = False,
         **policy_kwargs: Any,
     ) -> Trajectories:
         """Sample trajectories sequentially.
@@ -296,7 +296,7 @@ def local_search(
         trajectories: Trajectories,
         conditioning: torch.Tensor | None = None,
         save_estimator_outputs: bool = False,
-        save_logprobs: bool = True,
+        save_logprobs: bool = False,
         back_steps: torch.Tensor | None = None,
         back_ratio: float | None = None,
         use_metropolis_hastings: bool = True,
@@ -456,7 +456,7 @@ def sample_trajectories(
         states: Optional[States] = None,
         conditioning: Optional[torch.Tensor] = None,
         save_estimator_outputs: bool = False,  # FIXME: currently does not work when this is True
-        save_logprobs: bool = True,  # TODO: Support save_logprobs=True
+        save_logprobs: bool = False,  # TODO: Support save_logprobs=True
         n_local_search_loops: int = 0,
         back_steps: torch.Tensor | None = None,
         back_ratio: float | None = None,
tutorials/examples/train_hypergrid_simple.py (4 changes: 2 additions & 2 deletions)

@@ -57,8 +57,8 @@ def main(args):
         trajectories = sampler.sample_trajectories(
             env,
             n=args.batch_size,
-            save_logprobs=False,
-            save_estimator_outputs=True,
+            save_logprobs=True,
+            save_estimator_outputs=False,
             epsilon=args.epsilon,
         )
         visited_terminating_states.extend(trajectories.last_states)
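
This example now saves logprobs at sampling time even though epsilon-greedy exploration can make the samples off-policy; the replaced variant (caching estimator outputs) remains available. A hypothetical toggle, not part of the script, that picks what to cache based on the exploration setting (treating epsilon == 0.0 as a proxy for pure on-policy sampling is an assumption):

    # Hypothetical: epsilon == 0.0 taken as a proxy for pure on-policy sampling.
    on_policy = args.epsilon == 0.0
    trajectories = sampler.sample_trajectories(
        env,
        n=args.batch_size,
        save_logprobs=on_policy,               # reuse sampling logprobs when on-policy
        save_estimator_outputs=not on_policy,  # cache raw outputs when exploring
        epsilon=args.epsilon,
    )
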
tutorials/examples/train_ising.py (2 changes: 1 addition & 1 deletion)

@@ -85,7 +85,7 @@ def ising_n_to_ij(L, n):
         env,
         n=8,
         save_estimator_outputs=False,
-        save_logprobs=True,
+        save_logprobs=False,
     )
     training_samples = gflownet.to_training_samples(trajectories)
     optimizer.zero_grad()