Skip to content

Commit

Permalink
updated scripts with new API and tweaked tests (with reproducibility)
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 27, 2023
1 parent b0432c9 commit b987a39
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 31 deletions.
2 changes: 1 addition & 1 deletion tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_hypergrid(ndim: int, height: int):
args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories)
final_l1_dist = train_hypergrid_main(args)
if ndim == 2 and height == 8:
assert np.isclose(final_l1_dist, 9.14e-4, atol=1e-5)
assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-5)
elif ndim == 2 and height == 16:
assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-5)
elif ndim == 4 and height == 8:
Expand Down
17 changes: 11 additions & 6 deletions tutorials/examples/train_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
BoxStateFlowModule,
)
from gfn.modules import ScalarEstimator
from gfn.utils.common import set_seed

DEFAULT_SEED = 4444

Expand Down Expand Up @@ -86,7 +87,7 @@ def estimate_jsd(kde1, kde2):

def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else DEFAULT_SEED
torch.manual_seed(seed)
set_seed(seed)

device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

Expand Down Expand Up @@ -157,28 +158,28 @@ def main(args): # noqa: C901
pf=pf_estimator,
pb=pb_estimator,
logF=logF_estimator,
on_policy=True,
off_policy=False,
)
else:
gflownet = SubTBGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
logF=logF_estimator,
on_policy=True,
off_policy=False,
weighting=args.subTB_weighting,
lamda=args.subTB_lambda,
)
elif args.loss == "TB":
gflownet = TBGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
on_policy=True,
off_policy=False,
)
elif args.loss == "ZVar":
gflownet = LogPartitionVarianceGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
on_policy=True,
off_policy=False,
)

assert gflownet is not None, f"No gflownet for loss {args.loss}"
Expand Down Expand Up @@ -231,7 +232,11 @@ def main(args): # noqa: C901
if iteration % 1000 == 0:
print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")

trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size)
trajectories = gflownet.sample_trajectories(
env,
sample_off_policy=False,
n_samples=args.batch_size
)

training_samples = gflownet.to_training_samples(trajectories)

Expand Down
10 changes: 8 additions & 2 deletions tutorials/examples/train_discreteebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
from gfn.utils.common import validate
from gfn.utils.modules import NeuralNet, Tabular

from gfn.utils.common import set_seed

DEFAULT_SEED = 4444


def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else DEFAULT_SEED
torch.manual_seed(seed)
set_seed(seed)

device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

Expand Down Expand Up @@ -69,7 +71,11 @@ def main(args): # noqa: C901
n_iterations = args.n_trajectories // args.batch_size
validation_info = {"l1_dist": float("inf")}
for iteration in trange(n_iterations):
trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size)
trajectories = gflownet.sample_trajectories(
env,
off_policy=False,
n_samples=args.batch_size
)
training_samples = gflownet.to_training_samples(trajectories)

optimizer.zero_grad()
Expand Down
21 changes: 11 additions & 10 deletions tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
from gfn.utils.common import validate
from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular

from gfn.utils.common import set_seed

DEFAULT_SEED = 4444


def main(args): # noqa: C901
seed = args.seed if args.seed != 0 else DEFAULT_SEED
torch.manual_seed(seed)

set_seed(seed)
off_policy_sampling = False if args.replay_buffer_size == 0 else True
device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

use_wandb = len(args.wandb_project) > 0
Expand Down Expand Up @@ -122,7 +124,7 @@ def main(args): # noqa: C901
gflownet = ModifiedDBGFlowNet(
pf_estimator,
pb_estimator,
True if args.replay_buffer_size == 0 else False,
off_policy_sampling,
)

if args.loss in ("DB", "SubTB"):
Expand Down Expand Up @@ -153,34 +155,33 @@ def main(args): # noqa: C901
pf=pf_estimator,
pb=pb_estimator,
logF=logF_estimator,
on_policy=True if args.replay_buffer_size == 0 else False,
off_policy=off_policy_sampling,
)
else:
gflownet = SubTBGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
logF=logF_estimator,
on_policy=True if args.replay_buffer_size == 0 else False,
off_policy=off_policy_sampling,
weighting=args.subTB_weighting,
lamda=args.subTB_lambda,
)
elif args.loss == "TB":
gflownet = TBGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
on_policy=True if args.replay_buffer_size == 0 else False,
off_policy=off_policy_sampling,
)
elif args.loss == "ZVar":
gflownet = LogPartitionVarianceGFlowNet(
pf=pf_estimator,
pb=pb_estimator,
on_policy=True if args.replay_buffer_size == 0 else False,
off_policy=off_policy_sampling,
)

assert gflownet is not None, f"No gflownet for loss {args.loss}"

# Initialize the replay buffer ?

replay_buffer = None
if args.replay_buffer_size > 0:
if args.loss in ("TB", "SubTB", "ZVar"):
Expand Down Expand Up @@ -224,7 +225,7 @@ def main(args): # noqa: C901
n_iterations = args.n_trajectories // args.batch_size
validation_info = {"l1_dist": float("inf")}
for iteration in trange(n_iterations):
trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size)
trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling)
training_samples = gflownet.to_training_samples(trajectories)
if replay_buffer is not None:
with torch.no_grad():
Expand Down Expand Up @@ -290,7 +291,7 @@ def main(args): # noqa: C901
parser.add_argument(
"--replay_buffer_size",
type=int,
default=0,
default=100,
help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.",
)

Expand Down
16 changes: 4 additions & 12 deletions tutorials/examples/train_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from gfn.states import States
from gfn.utils import NeuralNet

from gfn.utils.common import set_seed


class Line(Env):
"""Mixture of Gaussians Line environment."""
Expand Down Expand Up @@ -287,16 +289,6 @@ def to_probability_distribution(
n_steps=self.n_steps_per_trajectory,
)


def fix_seed(seed):
"""Reproducibility."""
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)


def train(
gflownet,
env,
Expand All @@ -308,7 +300,7 @@ def train(
exploration_var_starting_val=2,
):
"""Trains a GFlowNet on the Line Environment."""
fix_seed(seed)
set_seed(seed)
n_iterations = int(n_trajectories // batch_size)

# TODO: Add in the uniform pb demo?
Expand Down Expand Up @@ -400,7 +392,7 @@ def train(
policy_std_max=policy_std_max,
)
pb = StepEstimator(environment, pb_module, backward=True)
gflownet = TBGFlowNet(pf=pf, pb=pb, on_policy=False, init_logZ=0.0)
gflownet = TBGFlowNet(pf=pf, pb=pb, off_policy=False, init_logZ=0.0)

gflownet = train(
gflownet,
Expand Down

0 comments on commit b987a39

Please sign in to comment.