Skip to content

Commit

Permalink
Removed an internal assertion for the optional stable value and inste… (
Browse files Browse the repository at this point in the history
pytorch#117414)

…ad defaulted to the standard (=false).

Fixes pytorch#117255.
Pull Request resolved: pytorch#117414
Approved by: https://github.com/ezyang
  • Loading branch information
tringwald authored and pytorchmergebot committed Jan 17, 2024
1 parent 1872834 commit 4a54ab3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
5 changes: 1 addition & 4 deletions aten/src/ATen/native/Sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ TORCH_META_FUNC(topk)

TORCH_META_FUNC2(sort, stable)
(const Tensor& self, c10::optional<bool> stable, int64_t dim, bool descending) {
TORCH_INTERNAL_ASSERT(
stable.has_value(),
"sort(): c10::optional<bool> for stable has to have value.");
maybe_wrap_dim(dim, self.dim());

// See issue: https://github.com/pytorch/pytorch/issues/65863
Expand Down Expand Up @@ -953,7 +950,7 @@ TORCH_IMPL_FUNC(sort_stable_out)
indices.zero_();
} else {
dim = maybe_wrap_dim(dim, self.dim());
sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value());
sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value_or(false));
}
}

Expand Down
7 changes: 7 additions & 0 deletions test/test_sort_and_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ def test_sort(self, device):
self.assertIsOrdered('descending', x, res2val, res2ind,
'random with NaNs')

def test_sort_stable_none(self):
# Called sort with stable=None used to trigger an assertion
# See https://github.com/pytorch/pytorch/issues/117255
x = torch.ones(10)
y = x.sort(stable=None).values
self.assertTrue(torch.all(y == torch.ones(10)).item())

@onlyCUDA
def test_sort_large_slice(self, device):
# tests direct cub path
Expand Down

0 comments on commit 4a54ab3

Please sign in to comment.