Skip to content

Commit

Permalink
add optional mask argument to sparse.prune
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Apr 27, 2021
1 parent 8a3c8b8 commit 69b9555
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
19 changes: 11 additions & 8 deletions nutils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,14 @@ def dedup(data, inplace=False):
numpy.add.at(dedup['value'], offsets, data['value'][1:])
return dedup

def prune(data, inplace=False):
def prune(data, inplace=False, mask=None):
'''Prune zero values.
Prune returns a sparse object with all zero values removed. If ``inplace`` is
true the returned object reuses the input array's memory. This may affect the
size of the array, which should no longer be used after pruning in place. In
case the input has no zeros the input array is returned.
Prune returns a sparse object with all zero values removed, or all entries
for which the boolean vector ``mask`` is true if it is specified. If
``inplace`` is true the returned object reuses the input array's memory. This
may affect the size of the array, which should no longer be used after
pruning in place. In case the input has no zeros the input array is returned.
>>> from nutils.sparse import dtype, prune
>>> from numpy import array
Expand All @@ -172,11 +173,13 @@ def prune(data, inplace=False):
dtype=[('index', [((2, 'i0'), 'u1'), ((2, 'i1'), 'u1')]), ('value', '<f8')])
'''

if data['value'].all():
if mask is None:
mask = data['value']
if mask.all():
return data
elif inplace:
buf = numpy.empty(chunksize // data.dtype.itemsize or 1, dtype=data.dtype)
nz, = data['value'].nonzero()
nz, = mask.nonzero()
for i in range(0, len(nz), len(buf)):
s = nz[i:i+len(buf)]
overlap = i+len(s) > s[0]
Expand All @@ -186,7 +189,7 @@ def prune(data, inplace=False):
data[i:i+len(s)] = chunk
return _resize(data, len(nz))
else:
return numpy.compress(data['value'], data)
return numpy.compress(mask, data)

def add(datas):
'''Add sparse objects.
Expand Down
18 changes: 18 additions & 0 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def test_prune(self):
self.assertEqual(prune.tolist(),
[((4,),10), ((4,),20), ((3,),1), ((2,),30), ((1,),40), ((2,),50), ((3,),-1), ((0,),60)])

def test_prune_mask(self):
mask = (numpy.arange(9) % 2).astype(bool)
for inplace in False, True:
with self.subTest(inplace=inplace), chunksize(self.data.itemsize * 3):
prune = sparse.prune(self.data, inplace=inplace, mask=mask)
(self.assertIs if inplace else self.assertIsNot)(prune, self.data)
self.assertEqual(prune.tolist(),
[((4,),20), ((2,),30), ((2,),50), ((0,),0)])

def test_block(self):
A = self.data
B = C = numpy.array([
Expand Down Expand Up @@ -218,6 +227,15 @@ def test_prune(self):
self.assertEqual(prune.tolist(),
[((2,4),10), ((3,4),20), ((2,3),1), ((1,2),30), ((0,1),40), ((1,2),50), ((2,3),-1), ((2,0),60)])

def test_prune_mask(self):
mask = (numpy.arange(8) % 2).astype(bool)
for inplace in False, True:
with self.subTest(inplace=inplace), chunksize(self.data.itemsize * 3):
prune = sparse.prune(self.data, inplace=inplace, mask=mask)
(self.assertIs if inplace else self.assertIsNot)(prune, self.data)
self.assertEqual(prune.tolist(),
[((3,4),20), ((1,2),30), ((1,2),50), ((3,0),0)])

def test_block(self):
A = self.data
B = numpy.array([((1,0), 10)], dtype=sparse.dtype([4,2]))
Expand Down

0 comments on commit 69b9555

Please sign in to comment.