Skip to content

Commit

Permalink
very simple usage example
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Feb 17, 2024
1 parent 7553aac commit 61c08be
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions test.py → tutorials/examples/train_hypergrid_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@
import torch
from tqdm import tqdm

from gfn.gym import HyperGrid
from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils import NeuralNet


torch.manual_seed(0)
exploration_rate = 0.5
learning_rate = 0.0005

env = HyperGrid(ndim=5, height=2)
# Setup the Environment.
env = HyperGrid(
ndim=5,
height=2,
device_str="cuda" if torch.cuda.is_available() else "cpu",
)

# Build the GFlowNet.
module_PF = NeuralNet(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions,
Expand All @@ -24,19 +29,25 @@
output_dim=env.n_actions - 1,
torso=module_PF.torso,
)
pf_estimator = DiscretePolicyEstimator(
module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor
)
pb_estimator = DiscretePolicyEstimator(
module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
)
gflownet = TBGFlowNet(init_logZ=0.0, pf=pf_estimator, pb=pb_estimator, off_policy=True)

pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)
gflownet = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator, off_policy=True)
# Feed pf to the sampler.
sampler = Sampler(estimator=pf_estimator)

# Policy parameters have their own LR.
non_logz_params = [v for k, v in dict(gflownet.named_parameters()).items() if k != "logZ"]
optimizer = torch.optim.Adam(non_logz_params, lr=1e-3)
# Move the gflownet to the GPU, if one is available.
if torch.cuda.is_available():
gflownet = gflownet.to("cuda")

# Log Z gets a dedicated learning rate (typically higher).
logz_params = [dict(gflownet.named_parameters())["logZ"]]
optimizer.add_param_group({"params": logz_params, "lr": 1e-1})
# Policy parameters have their own LR. Log Z gets a dedicated learning rate
# (typically higher).
optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": 1e-1})

n_iterations = int(1e4)
batch_size = int(1e5)
Expand Down

0 comments on commit 61c08be

Please sign in to comment.