From 7ea8f9dfbf0bca2997bb3262ec0c3bd86f6a209e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 16 Nov 2023 18:02:01 -0500 Subject: [PATCH] Fix: clone of policy_outputs into logits in get_logprobs must not be detached --- gflownet/envs/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index a14927c9d..62c06ecd0 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -546,7 +546,7 @@ def get_logprobs( """ device = policy_outputs.device ns_range = torch.arange(policy_outputs.shape[0]).to(device) - logits = policy_outputs.clone().detach() + logits = policy_outputs.clone() if mask is not None: logits[mask] = -torch.inf action_indices = (