Skip to content

Commit

Permalink
Merge pull request #607 from soraros/node-dtype
Browse files Browse the repository at this point in the history
Attach dtype to nodes
  • Loading branch information
gertjanvanzwieten authored Jan 19, 2021
2 parents 14140e8 + 5f1074a commit 27dfc4b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
9 changes: 5 additions & 4 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,8 @@ def __iter__(self):
__pow__ = power
__abs__ = lambda self: abs(self)
__mod__ = lambda self, other: mod(self, other)
__str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, ','.join(map(str, self.shape)) if hasattr(self, 'shape') else '?')
__str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, self._shape_str(form=str))
_shape_str = lambda self, form: '{}:{}'.format(self.dtype.__name__[0] if hasattr(self, 'dtype') else '?', ','.join(map(form, self._axes)) if hasattr(self, '_axes') else '?')

sum = sum
prod = product
Expand Down Expand Up @@ -923,7 +924,7 @@ def _node(self, cache, subgraph, times):
if self in cache:
return cache[self]
args = tuple(arg._node(cache, subgraph, times) for arg in self._Evaluable__args)
label = '\n'.join(filter(None, (type(self).__name__, self._node_details, ','.join(map(repr, self._axes)))))
label = '\n'.join(filter(None, (type(self).__name__, self._node_details, self._shape_str(form=repr))))
cache[self] = node = RegularNode(label, args, {}, (type(self).__name__, times[self]), subgraph)
return node

Expand Down Expand Up @@ -2855,13 +2856,13 @@ def _derivative(self, var, seen):
return zeros(self.shape+var.shape)

def __str__(self):
return '{} {!r} <{}>'.format(self.__class__.__name__, self._name, ','.join(map(str, self.shape)))
return '{} {!r} <{}>'.format(self.__class__.__name__, self._name, self._shape_str(form=str))

def _node(self, cache, subgraph, times):
if self in cache:
return cache[self]
else:
label = '\n'.join(filter(None, (type(self).__name__, self._name, ','.join(map(repr, self._axes)))))
label = '\n'.join(filter(None, (type(self).__name__, self._name, self._shape_str(form=repr))))
cache[self] = node = DuplicatedLeafNode(label, (type(self).__name__, times[self]))
return node

Expand Down
2 changes: 2 additions & 0 deletions nutils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def setUp(self):
to ignore warnings for the entire class.
'''

maxDiff = None # prevent assertEqual from shortening the diff error message

def enter_context(self, ctx):
retval = ctx.__enter__()
self.addCleanup(ctx.__exit__, None, None, None)
Expand Down
42 changes: 23 additions & 19 deletions tests/test_evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def assertFunctionAlmostEqual(self, actual, desired, decimal):
numpy.add.at(dense, indices, values)
self.assertArrayAlmostEqual(dense, desired, decimal)

def test_str(self):
a = evaluable.Array((), shape=(2, 3), dtype=float)
self.assertEqual(str(a),'nutils.evaluable.Array<f:2,3>')

def test_evalconst(self):
constargs = [numpy.random.uniform(size=shape) for shape in self.shapes]
self.assertFunctionAlmostEqual(decimal=15,
Expand Down Expand Up @@ -681,17 +685,17 @@ class asciitree(TestCase):
def test_asciitree(self):
f = evaluable.Sin(evaluable.Inflate(1, evaluable.Zeros((), int), 2)**evaluable.Diagonalize(evaluable.Argument('arg', (2,))))
self.assertEqual(f.asciitree(richoutput=True),
'%0 = Sin; a2,a2\n'
'└ %1 = Power; a2,a2\n'
' ├ %2 = Transpose; 1,0; i2,s2\n'
' │ └ %3 = InsertAxis; s2,i2\n'
' │ ├ %4 = Inflate; s2\n'
'%0 = Sin; f:a2,a2\n'
'└ %1 = Power; f:a2,a2\n'
' ├ %2 = Transpose; 1,0; i:i2,s2\n'
' │ └ %3 = InsertAxis; i:s2,i2\n'
' │ ├ %4 = Inflate; i:s2\n'
' │ │ ├ 1\n'
' │ │ ├ 0\n'
' │ │ └ 2\n'
' │ └ 2\n'
' └ %5 = Diagonalize; d2,d2\n'
' └ Argument; arg; a2\n')
' └ %5 = Diagonalize; f:d2,d2\n'
' └ Argument; arg; f:a2\n')

@unittest.skipIf(sys.version_info < (3, 6), 'test requires dicts maintaining insertion order')
def test_loop_sum(self):
Expand All @@ -717,19 +721,19 @@ def test_loop_concatenate(self):
'NODES\n'
'%B0 = LoopConcatenate\n'
'├ shape[0] = 2\n'
'├ start = %B1 = Take\n'
'│ ├ %A2 = _SizesToOffsets; a3\n'
'│ │ └ %A3 = InsertAxis; i2\n'
'├ start = %B1 = Take; i:\n'
'│ ├ %A2 = _SizesToOffsets; i:a3\n'
'│ │ └ %A3 = InsertAxis; i:i2\n'
'│ │ ├ 1\n'
'│ │ └ 2\n'
'│ └ %B4 = LoopIndex\n'
'│ └ length = 2\n'
'├ stop = %B5 = Take\n'
'├ stop = %B5 = Take; i:\n'
'│ ├ %A2\n'
'│ └ %B6 = Add\n'
'│ └ %B6 = Add; i:\n'
'│ ├ %B4\n'
'│ └ 1\n'
'└ func = %B7 = InsertAxis; i1\n'
'└ func = %B7 = InsertAxis; i:i1\n'
' ├ %B4\n'
' └ 1\n')

Expand All @@ -744,19 +748,19 @@ def test_loop_concatenatecombined(self):
'NODES\n'
'%B0 = LoopConcatenate\n'
'├ shape[0] = 2\n'
'├ start = %B1 = Take\n'
'│ ├ %A2 = _SizesToOffsets; a3\n'
'│ │ └ %A3 = InsertAxis; i2\n'
'├ start = %B1 = Take; i:\n'
'│ ├ %A2 = _SizesToOffsets; i:a3\n'
'│ │ └ %A3 = InsertAxis; i:i2\n'
'│ │ ├ 1\n'
'│ │ └ 2\n'
'│ └ %B4 = LoopIndex\n'
'│ └ length = 2\n'
'├ stop = %B5 = Take\n'
'├ stop = %B5 = Take; i:\n'
'│ ├ %A2\n'
'│ └ %B6 = Add\n'
'│ └ %B6 = Add; i:\n'
'│ ├ %B4\n'
'│ └ 1\n'
'└ func = %B7 = InsertAxis; i1\n'
'└ func = %B7 = InsertAxis; i:i1\n'
' ├ %B4\n'
' └ 1\n')

Expand Down

0 comments on commit 27dfc4b

Please sign in to comment.