From 4a54ab328c5dfd882db90411c0c1f7fbc4f19b38 Mon Sep 17 00:00:00 2001 From: Tobias Ringwald Date: Wed, 17 Jan 2024 02:25:21 +0000 Subject: [PATCH] =?UTF-8?q?Removed=20an=20internal=20assertion=20for=20the?= =?UTF-8?q?=20optional=20stable=20value=20and=20inste=E2=80=A6=20(#117414)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ad defaulted to the standard (=false). Fixes #117255. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117414 Approved by: https://github.com/ezyang --- aten/src/ATen/native/Sorting.cpp | 5 +---- test/test_sort_and_select.py | 7 +++++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 91f05e367fed2..5726220c7b954 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -72,9 +72,6 @@ TORCH_META_FUNC(topk) TORCH_META_FUNC2(sort, stable) (const Tensor& self, c10::optional stable, int64_t dim, bool descending) { - TORCH_INTERNAL_ASSERT( - stable.has_value(), - "sort(): c10::optional for stable has to have value."); maybe_wrap_dim(dim, self.dim()); // See issue: https://github.com/pytorch/pytorch/issues/65863 @@ -953,7 +950,7 @@ TORCH_IMPL_FUNC(sort_stable_out) indices.zero_(); } else { dim = maybe_wrap_dim(dim, self.dim()); - sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value()); + sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value_or(false)); } } diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 9439382b98582..7709131e61020 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -137,6 +137,13 @@ def test_sort(self, device): self.assertIsOrdered('descending', x, res2val, res2ind, 'random with NaNs') + def test_sort_stable_none(self): + # Called sort with stable=None used to trigger an assertion + # See https://github.com/pytorch/pytorch/issues/117255 + x = torch.ones(10) + y = x.sort(stable=None).values + self.assertTrue(torch.all(y == torch.ones(10)).item()) + @onlyCUDA def test_sort_large_slice(self, device): # tests direct cub path