Skip to content

Commit

Permalink
Merge pull request #296 from alexhernandezgarcia/tiny_bug_bootstrap
Browse files Browse the repository at this point in the history
Tiny bug fix in bootstrap
  • Loading branch information
alexhernandezgarcia authored Mar 13, 2024
2 parents 24640af + fc877d2 commit 030f207
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion gflownet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,9 @@ def bootstrap_samples(tensor, num_samples):
returns tensor of the shape [initial_shape, num_samples]
"""
dim_size = tensor.size(-1)
bs_indices = torch.randint(0, dim_size, size=(num_samples * dim_size,))
bs_indices = torch.randint(
0, dim_size, size=(num_samples * dim_size,), device=tensor.device
)
bs_samples = torch.index_select(tensor, -1, index=bs_indices)
bs_samples = bs_samples.view(
tensor.size()[:-1] + (num_samples, dim_size)
Expand Down

0 comments on commit 030f207

Please sign in to comment.