Skip to content

Commit

Permalink
black applied
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Jan 24, 2025
1 parent 89ffe04 commit 25ac995
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 23 deletions.
24 changes: 15 additions & 9 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,9 @@ def reverse_backward_trajectories(
# Initialize new actions and states
new_actions = trajectories.env.dummy_action.repeat(
max_len + 1, len(trajectories), 1
).to(actions) # shape (max_len + 1, n_trajectories, *action_dim)
).to(
actions
) # shape (max_len + 1, n_trajectories, *action_dim)
new_states = trajectories.env.sf.repeat(max_len + 2, len(trajectories), 1).to(
states
) # shape (max_len + 2, n_trajectories, *state_dim)
Expand Down Expand Up @@ -492,9 +494,9 @@ def reverse_backward_trajectories(

# Assign reversed actions to new_actions
new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]]
new_actions[torch.arange(len(trajectories)), seq_lengths] = (
trajectories.env.exit_action
)
new_actions[
torch.arange(len(trajectories)), seq_lengths
] = trajectories.env.exit_action

# Assign reversed states to new_states
assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0"
Expand Down Expand Up @@ -529,15 +531,19 @@ def reverse_backward_trajectories(
if debug:
_new_actions = trajectories.env.dummy_action.repeat(
max_len + 1, len(trajectories), 1
).to(actions) # shape (max_len + 1, n_trajectories, *action_dim)
).to(
actions
) # shape (max_len + 1, n_trajectories, *action_dim)
_new_states = trajectories.env.sf.repeat(
max_len + 2, len(trajectories), 1
).to(states) # shape (max_len + 2, n_trajectories, *state_dim)
).to(
states
) # shape (max_len + 2, n_trajectories, *state_dim)

for i in range(len(trajectories)):
_new_actions[trajectories.when_is_done[i], i] = (
trajectories.env.exit_action
)
_new_actions[
trajectories.when_is_done[i], i
] = trajectories.env.exit_action
_new_actions[
: trajectories.when_is_done[i], i
] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(
Expand Down
22 changes: 10 additions & 12 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,6 @@ def _combine_prev_and_recon_trajectories( # noqa: C901

bs = prev_trajectories.n_trajectories
device = prev_trajectories.states.device
state_shape = prev_trajectories.states.state_shape
action_shape = prev_trajectories.env.action_shape
env = prev_trajectories.env

# Obtain full trajectories by concatenating the backward and forward parts.
Expand Down Expand Up @@ -590,12 +588,12 @@ def _combine_prev_and_recon_trajectories( # noqa: C901

# Prepare the new states and actions
# Note that these are initialized in transposed shapes
new_trajectories_states_tsr = prev_trajectories.env.sf.repeat(
bs, max_traj_len + 1, 1
).to(prev_trajectories.states.tensor)
new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat(
bs, max_traj_len, 1
).to(prev_trajectories.actions.tensor)
new_trajectories_states_tsr = env.sf.repeat(bs, max_traj_len + 1, 1).to(
prev_trajectories.states.tensor
)
new_trajectories_actions_tsr = env.dummy_action.repeat(bs, max_traj_len, 1).to(
prev_trajectories.actions.tensor
)

# Assign the first part (backtracked from backward policy) of the trajectory
prev_mask_truc = prev_mask[:, :max_n_prev]
Expand Down Expand Up @@ -664,10 +662,10 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
# If `debug` is True (expected only when testing), compare the
# vectorized approach's results (above) to the for-loop results (below).
if debug:
_new_trajectories_states_tsr = prev_trajectories.env.sf.repeat(
max_traj_len + 1, bs, 1
).to(prev_trajectories.states.tensor)
_new_trajectories_actions_tsr = prev_trajectories.env.dummy_action.repeat(
_new_trajectories_states_tsr = env.sf.repeat(max_traj_len + 1, bs, 1).to(
prev_trajectories.states.tensor
)
_new_trajectories_actions_tsr = env.dummy_action.repeat(
max_traj_len, bs, 1
).to(prev_trajectories.actions.tensor)

Expand Down
6 changes: 4 additions & 2 deletions tutorials/examples/train_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
BoxStateFlowModule,
)
from gfn.modules import ScalarEstimator
from gfn.samplers import Sampler, LocalSearchSampler
from gfn.samplers import LocalSearchSampler, Sampler
from gfn.utils.common import set_seed

DEFAULT_SEED = 4444
Expand Down Expand Up @@ -186,7 +186,9 @@ def main(args): # noqa: C901
sampler = Sampler(estimator=pf_estimator)
local_search_params = {}
else:
sampler = LocalSearchSampler(pf_estimator=pf_estimator, pb_estimator=pb_estimator)
sampler = LocalSearchSampler(
pf_estimator=pf_estimator, pb_estimator=pb_estimator
)
local_search_params = {
"n_local_search_loops": args.n_local_search_loops,
"back_ratio": args.back_ratio,
Expand Down

0 comments on commit 25ac995

Please sign in to comment.