From 25ac995ea2253980cf939a75074190da5c93d73b Mon Sep 17 00:00:00 2001
From: "sanghyeok.choi"
Date: Fri, 24 Jan 2025 23:43:09 +0900
Subject: [PATCH] black applied

---
 src/gfn/containers/trajectories.py | 24 +++++++++++++++---------
 src/gfn/samplers.py                | 22 ++++++++++------------
 tutorials/examples/train_box.py    |  6 ++++--
 3 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py
index 68bfc5a..fe6438e 100644
--- a/src/gfn/containers/trajectories.py
+++ b/src/gfn/containers/trajectories.py
@@ -458,7 +458,9 @@ def reverse_backward_trajectories(
     # Initialize new actions and states
     new_actions = trajectories.env.dummy_action.repeat(
         max_len + 1, len(trajectories), 1
-    ).to(actions)  # shape (max_len + 1, n_trajectories, *action_dim)
+    ).to(
+        actions
+    )  # shape (max_len + 1, n_trajectories, *action_dim)
     new_states = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to(
         states
     )  # shape (max_len + 2, n_trajectories, *state_dim)
@@ -492,9 +494,9 @@ def reverse_backward_trajectories(
 
     # Assign reversed actions to new_actions
     new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]]
-    new_actions[torch.arange(len(trajectories)), seq_lengths] = (
-        trajectories.env.exit_action
-    )
+    new_actions[
+        torch.arange(len(trajectories)), seq_lengths
+    ] = trajectories.env.exit_action
 
     # Assign reversed states to new_states
     assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0"
@@ -529,15 +531,19 @@ def reverse_backward_trajectories(
     if debug:
         _new_actions = trajectories.env.dummy_action.repeat(
             max_len + 1, len(trajectories), 1
-        ).to(actions)  # shape (max_len + 1, n_trajectories, *action_dim)
+        ).to(
+            actions
+        )  # shape (max_len + 1, n_trajectories, *action_dim)
         _new_states = trajectories.env.sf.repeat(
             max_len + 2, len(trajectories), 1
-        ).to(states)  # shape (max_len + 2, n_trajectories, *state_dim)
+        ).to(
+            states
+        )  # shape (max_len + 2, n_trajectories, *state_dim)
 
         for i in range(len(trajectories)):
-            _new_actions[trajectories.when_is_done[i], i] = (
-                trajectories.env.exit_action
-            )
+            _new_actions[
+                trajectories.when_is_done[i], i
+            ] = trajectories.env.exit_action
             _new_actions[
                 : trajectories.when_is_done[i], i
             ] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index b170180..f5e359b 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -554,8 +554,6 @@ def _combine_prev_and_recon_trajectories(  # noqa: C901
 
     bs = prev_trajectories.n_trajectories
     device = prev_trajectories.states.device
-    state_shape = prev_trajectories.states.state_shape
-    action_shape = prev_trajectories.env.action_shape
     env = prev_trajectories.env
 
     # Obtain full trajectories by concatenating the backward and forward parts.
@@ -590,12 +588,12 @@ def _combine_prev_and_recon_trajectories(  # noqa: C901
 
     # Prepare the new states and actions
     # Note that these are initialized in transposed shapes
-    new_trajectories_states_tsr = prev_trajectories.env.sf.repeat(
-        bs, max_traj_len + 1, 1
-    ).to(prev_trajectories.states.tensor)
-    new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat(
-        bs, max_traj_len, 1
-    ).to(prev_trajectories.actions.tensor)
+    new_trajectories_states_tsr = env.sf.repeat(bs, max_traj_len + 1, 1).to(
+        prev_trajectories.states.tensor
+    )
+    new_trajectories_actions_tsr = env.dummy_action.repeat(bs, max_traj_len, 1).to(
+        prev_trajectories.actions.tensor
+    )
 
     # Assign the first part (backtracked from backward policy) of the trajectory
     prev_mask_truc = prev_mask[:, :max_n_prev]
@@ -664,10 +662,10 @@ def _combine_prev_and_recon_trajectories(  # noqa: C901
     # If `debug` is True (expected only when testing), compare the
     # vectorized approach's results (above) to the for-loop results (below).
     if debug:
-        _new_trajectories_states_tsr = prev_trajectories.env.sf.repeat(
-            max_traj_len + 1, bs, 1
-        ).to(prev_trajectories.states.tensor)
-        _new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat(
+        _new_trajectories_states_tsr = env.sf.repeat(max_traj_len + 1, bs, 1).to(
+            prev_trajectories.states.tensor
+        )
+        _new_trajectories_actions_tsr = env.dummy_action.repeat(
             max_traj_len, bs, 1
         ).to(prev_trajectories.actions.tensor)
 
diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py
index 65b2583..dfcef4a 100644
--- a/tutorials/examples/train_box.py
+++ b/tutorials/examples/train_box.py
@@ -32,7 +32,7 @@
     BoxStateFlowModule,
 )
 from gfn.modules import ScalarEstimator
-from gfn.samplers import Sampler, LocalSearchSampler
+from gfn.samplers import LocalSearchSampler, Sampler
 from gfn.utils.common import set_seed
 
 DEFAULT_SEED = 4444
@@ -186,7 +186,9 @@ def main(args):  # noqa: C901
         sampler = Sampler(estimator=pf_estimator)
         local_search_params = {}
     else:
-        sampler = LocalSearchSampler(pf_estimator=pf_estimator, pb_estimator=pb_estimator)
+        sampler = LocalSearchSampler(
+            pf_estimator=pf_estimator, pb_estimator=pb_estimator
+        )
         local_search_params = {
             "n_local_search_loops": args.n_local_search_loops,
             "back_ratio": args.back_ratio,