Skip to content

Commit

Permalink
make pyright happy
Browse files Browse the repository at this point in the history
  • Loading branch information
Salem Lahlou committed Jan 11, 2025
1 parent 26126df commit 731e081
Showing 1 changed file with 38 additions and 26 deletions.
64 changes: 38 additions & 26 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,18 +610,26 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
new_trajectories_states_tsr[state_recon_mask] = recon_trajectories_states_tsr[
state_recon_mask2
]
new_trajectories_actions_tsr[action_recon_mask] = (
recon_trajectories_actions_tsr[action_recon_mask2]
)
new_trajectories_actions_tsr[
action_recon_mask
] = recon_trajectories_actions_tsr[action_recon_mask2]

# Transpose back
new_trajectories_states_tsr = new_trajectories_states_tsr.transpose(0, 1)
new_trajectories_actions_tsr = new_trajectories_actions_tsr.transpose(0, 1)

# Similarly, combine log_pf and log_pb if needed
if save_logprobs:
prev_trajectories_log_pf = prev_trajectories_log_pf.transpose(0, 1)
recon_trajectories_log_pf = recon_trajectories_log_pf.transpose(0, 1)
prev_trajectories_log_pf = (
prev_trajectories_log_pf.transpose(0, 1)
if prev_trajectories_log_pf is not None
else None
)
recon_trajectories_log_pf = (
recon_trajectories_log_pf.transpose(0, 1)
if recon_trajectories_log_pf is not None
else None
)
new_trajectories_log_pf = torch.full((bs, max_traj_len), 0.0).to(
device=device, dtype=torch.float
)
Expand Down Expand Up @@ -682,36 +690,40 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
_n_prev = n_prevs[i]

# Backward part
_new_trajectories_states_tsr[: _n_prev + 1, i] = (
prev_trajectories.states.tensor[: _n_prev + 1, i]
)
_new_trajectories_actions_tsr[:_n_prev, i] = (
prev_trajectories.actions.tensor[:_n_prev, i]
)
_new_trajectories_states_tsr[
: _n_prev + 1, i
] = prev_trajectories.states.tensor[: _n_prev + 1, i]
_new_trajectories_actions_tsr[
:_n_prev, i
] = prev_trajectories.actions.tensor[:_n_prev, i]

# Forward part
_len_recon = recon_trajectories.when_is_done[i]
_new_trajectories_states_tsr[
_n_prev + 1 : _n_prev + _len_recon + 1, i
] = recon_trajectories.states.tensor[1 : _len_recon + 1, i]
_new_trajectories_actions_tsr[_n_prev : _n_prev + _len_recon, i] = (
recon_trajectories.actions.tensor[:_len_recon, i]
)
_new_trajectories_actions_tsr[
_n_prev : _n_prev + _len_recon, i
] = recon_trajectories.actions.tensor[:_len_recon, i]

if save_logprobs:
_new_trajectories_log_pf[:_n_prev, i] = prev_trajectories_log_pf[ # type: ignore
:_n_prev, i
]
_new_trajectories_log_pf[_n_prev : _n_prev + _len_recon, i] = (
recon_trajectories_log_pf[:_len_recon, i] # type: ignore
)
if prev_trajectories_log_pf is not None: # Add this check
_new_trajectories_log_pf[
:_n_prev, i
] = prev_trajectories_log_pf[:_n_prev, i]
if recon_trajectories_log_pf is not None: # Add this check
_new_trajectories_log_pf[
_n_prev : _n_prev + _len_recon, i
] = recon_trajectories_log_pf[:_len_recon, i]
if use_metropolis_hastings:
_new_trajectories_log_pb[:_n_prev, i] = prev_trajectories_log_pb[ # type: ignore
:_n_prev, i
]
_new_trajectories_log_pb[_n_prev : _n_prev + _len_recon, i] = (
recon_trajectories_log_pb[:_len_recon, i] # type: ignore
)
if prev_trajectories_log_pb is not None: # Add this check
_new_trajectories_log_pb[
:_n_prev, i
] = prev_trajectories_log_pb[:_n_prev, i]
if recon_trajectories_log_pb is not None: # Add this check
_new_trajectories_log_pb[
_n_prev : _n_prev + _len_recon, i
] = recon_trajectories_log_pb[:_len_recon, i]

assert torch.all(
_new_trajectories_states_tsr == new_trajectories_states_tsr
Expand Down

0 comments on commit 731e081

Please sign in to comment.