Skip to content

Commit

Permalink
Fix: clone of policy_outputs into logits in get_logprobs must not be …
Browse files Browse the repository at this point in the history
…detached
  • Loading branch information
alexhernandezgarcia committed Nov 16, 2023
1 parent 619281d commit 7ea8f9d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def get_logprobs(
"""
device = policy_outputs.device
ns_range = torch.arange(policy_outputs.shape[0]).to(device)
logits = policy_outputs.clone().detach()
logits = policy_outputs.clone()
if mask is not None:
logits[mask] = -torch.inf
action_indices = (
Expand Down

0 comments on commit 7ea8f9d

Please sign in to comment.