From fc877d2f2cac2a4e938866cb2ec14c205701033d Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Fri, 1 Mar 2024 16:32:23 -0500 Subject: [PATCH] device bug fix --- gflownet/utils/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 697281976..dfc2919c2 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -298,7 +298,9 @@ def bootstrap_samples(tensor, num_samples): returns tensor of the shape [initial_shape, num_samples] """ dim_size = tensor.size(-1) - bs_indices = torch.randint(0, dim_size, size=(num_samples * dim_size,)) + bs_indices = torch.randint( + 0, dim_size, size=(num_samples * dim_size,), device=tensor.device + ) bs_samples = torch.index_select(tensor, -1, index=bs_indices) bs_samples = bs_samples.view( tensor.size()[:-1] + (num_samples, dim_size)