diff --git a/nutils/sparse.py b/nutils/sparse.py index d20224e0b..f4c85c36a 100644 --- a/nutils/sparse.py +++ b/nutils/sparse.py @@ -156,13 +156,14 @@ def dedup(data, inplace=False): numpy.add.at(dedup['value'], offsets, data['value'][1:]) return dedup -def prune(data, inplace=False): +def prune(data, inplace=False, mask=None): '''Prune zero values. - Prune returns a sparse object with all zero values removed. If ``inplace`` is - true the returned object reuses the input array's memory. This may affect the - size of the array, which should no longer be used after pruning in place. In - case the input has no zeros the input array is returned. + Prune returns a sparse object with all zero values removed, or all entries + for which the boolean vector ``mask`` is true if it is specified. If + ``inplace`` is true the returned object reuses the input array's memory. This + may affect the size of the array, which should no longer be used after + pruning in place. In case the input has no zeros the input array is returned. >>> from nutils.sparse import dtype, prune >>> from numpy import array @@ -172,11 +173,13 @@ def prune(data, inplace=False): dtype=[('index', [((2, 'i0'), 'u1'), ((2, 'i1'), 'u1')]), ('value', ' s[0] @@ -186,7 +189,7 @@ def prune(data, inplace=False): data[i:i+len(s)] = chunk return _resize(data, len(nz)) else: - return numpy.compress(data['value'], data) + return numpy.compress(mask, data) def add(datas): '''Add sparse objects. diff --git a/tests/test_sparse.py b/tests/test_sparse.py index 77d794375..5d362b6f9 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -101,6 +101,15 @@ def test_prune(self): self.assertEqual(prune.tolist(), [((4,),10), ((4,),20), ((3,),1), ((2,),30), ((1,),40), ((2,),50), ((3,),-1), ((0,),60)]) + def test_prune_mask(self): + mask = (numpy.arange(9) % 2).astype(bool) + for inplace in False, True: + with self.subTest(inplace=inplace), chunksize(self.data.itemsize * 3): + prune = sparse.prune(self.data, inplace=inplace, mask=mask) + (self.assertIs if inplace else self.assertIsNot)(prune, self.data) + self.assertEqual(prune.tolist(), + [((4,),20), ((2,),30), ((2,),50), ((0,),0)]) + def test_block(self): A = self.data B = C = numpy.array([ @@ -218,6 +227,15 @@ def test_prune(self): self.assertEqual(prune.tolist(), [((2,4),10), ((3,4),20), ((2,3),1), ((1,2),30), ((0,1),40), ((1,2),50), ((2,3),-1), ((2,0),60)]) + def test_prune_mask(self): + mask = (numpy.arange(8) % 2).astype(bool) + for inplace in False, True: + with self.subTest(inplace=inplace), chunksize(self.data.itemsize * 3): + prune = sparse.prune(self.data, inplace=inplace, mask=mask) + (self.assertIs if inplace else self.assertIsNot)(prune, self.data) + self.assertEqual(prune.tolist(), + [((3,4),20), ((1,2),30), ((1,2),50), ((3,0),0)]) + def test_block(self): A = self.data B = numpy.array([((1,0), 10)], dtype=sparse.dtype([4,2]))