From 731e081ec99ef3ab819f253f8ddcb097caedf7d8 Mon Sep 17 00:00:00 2001 From: Salem Lahlou Date: Sun, 12 Jan 2025 01:23:32 +0400 Subject: [PATCH] make pyright happy --- src/gfn/samplers.py | 64 +++++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 566bfed6..750af7cf 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -610,9 +610,9 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 new_trajectories_states_tsr[state_recon_mask] = recon_trajectories_states_tsr[ state_recon_mask2 ] - new_trajectories_actions_tsr[action_recon_mask] = ( - recon_trajectories_actions_tsr[action_recon_mask2] - ) + new_trajectories_actions_tsr[ + action_recon_mask + ] = recon_trajectories_actions_tsr[action_recon_mask2] # Transpose back new_trajectories_states_tsr = new_trajectories_states_tsr.transpose(0, 1) @@ -620,8 +620,16 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 # Similarly, combine log_pf and log_pb if needed if save_logprobs: - prev_trajectories_log_pf = prev_trajectories_log_pf.transpose(0, 1) - recon_trajectories_log_pf = recon_trajectories_log_pf.transpose(0, 1) + prev_trajectories_log_pf = ( + prev_trajectories_log_pf.transpose(0, 1) + if prev_trajectories_log_pf is not None + else None + ) + recon_trajectories_log_pf = ( + recon_trajectories_log_pf.transpose(0, 1) + if recon_trajectories_log_pf is not None + else None + ) new_trajectories_log_pf = torch.full((bs, max_traj_len), 0.0).to( device=device, dtype=torch.float ) @@ -682,36 +690,40 @@ def _combine_prev_and_recon_trajectories( # noqa: C901 _n_prev = n_prevs[i] # Backward part - _new_trajectories_states_tsr[: _n_prev + 1, i] = ( - prev_trajectories.states.tensor[: _n_prev + 1, i] - ) - _new_trajectories_actions_tsr[:_n_prev, i] = ( - prev_trajectories.actions.tensor[:_n_prev, i] - ) + _new_trajectories_states_tsr[ + : _n_prev + 1, i + ] = prev_trajectories.states.tensor[: _n_prev + 1, i] + _new_trajectories_actions_tsr[ + :_n_prev, i + ] = prev_trajectories.actions.tensor[:_n_prev, i] # Forward part _len_recon = recon_trajectories.when_is_done[i] _new_trajectories_states_tsr[ _n_prev + 1 : _n_prev + _len_recon + 1, i ] = recon_trajectories.states.tensor[1 : _len_recon + 1, i] - _new_trajectories_actions_tsr[_n_prev : _n_prev + _len_recon, i] = ( - recon_trajectories.actions.tensor[:_len_recon, i] - ) + _new_trajectories_actions_tsr[ + _n_prev : _n_prev + _len_recon, i + ] = recon_trajectories.actions.tensor[:_len_recon, i] if save_logprobs: - _new_trajectories_log_pf[:_n_prev, i] = prev_trajectories_log_pf[ # type: ignore - :_n_prev, i - ] - _new_trajectories_log_pf[_n_prev : _n_prev + _len_recon, i] = ( - recon_trajectories_log_pf[:_len_recon, i] # type: ignore - ) + if prev_trajectories_log_pf is not None: # Add this check + _new_trajectories_log_pf[ + :_n_prev, i + ] = prev_trajectories_log_pf[:_n_prev, i] + if recon_trajectories_log_pf is not None: # Add this check + _new_trajectories_log_pf[ + _n_prev : _n_prev + _len_recon, i + ] = recon_trajectories_log_pf[:_len_recon, i] if use_metropolis_hastings: - _new_trajectories_log_pb[:_n_prev, i] = prev_trajectories_log_pb[ # type: ignore - :_n_prev, i - ] - _new_trajectories_log_pb[_n_prev : _n_prev + _len_recon, i] = ( - recon_trajectories_log_pb[:_len_recon, i] # type: ignore - ) + if prev_trajectories_log_pb is not None: # Add this check + _new_trajectories_log_pb[ + :_n_prev, i + ] = prev_trajectories_log_pb[:_n_prev, i] + if recon_trajectories_log_pb is not None: # Add this check + _new_trajectories_log_pb[ + _n_prev : _n_prev + _len_recon, i + ] = recon_trajectories_log_pb[:_len_recon, i] assert torch.all( _new_trajectories_states_tsr == new_trajectories_states_tsr