diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 697281976..dfc2919c2 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -298,7 +298,9 @@ def bootstrap_samples(tensor, num_samples): returns tensor of the shape [initial_shape, num_samples] """ dim_size = tensor.size(-1) - bs_indices = torch.randint(0, dim_size, size=(num_samples * dim_size,)) + bs_indices = torch.randint( + 0, dim_size, size=(num_samples * dim_size,), device=tensor.device + ) bs_samples = torch.index_select(tensor, -1, index=bs_indices) bs_samples = bs_samples.view( tensor.size()[:-1] + (num_samples, dim_size)