From 5f1074af3ede9c5bd77c69d9a957a09a55b788ab Mon Sep 17 00:00:00 2001 From: sora <210at85@gmail.com> Date: Fri, 8 Jan 2021 18:44:46 +0100 Subject: [PATCH] attach dtype to nodes --- nutils/evaluable.py | 9 +++++---- nutils/testing.py | 2 ++ tests/test_evaluable.py | 42 ++++++++++++++++++++++------------------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index f5be14571..173876d63 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -858,7 +858,8 @@ def __iter__(self): __pow__ = power __abs__ = lambda self: abs(self) __mod__ = lambda self, other: mod(self, other) - __str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, ','.join(map(str, self.shape)) if hasattr(self, 'shape') else '?') + __str__ = __repr__ = lambda self: '{}.{}<{}>'.format(type(self).__module__, type(self).__name__, self._shape_str(form=str)) + _shape_str = lambda self, form: '{}:{}'.format(self.dtype.__name__[0] if hasattr(self, 'dtype') else '?', ','.join(map(form, self._axes)) if hasattr(self, '_axes') else '?') sum = sum prod = product @@ -923,7 +924,7 @@ def _node(self, cache, subgraph, times): if self in cache: return cache[self] args = tuple(arg._node(cache, subgraph, times) for arg in self._Evaluable__args) - label = '\n'.join(filter(None, (type(self).__name__, self._node_details, ','.join(map(repr, self._axes))))) + label = '\n'.join(filter(None, (type(self).__name__, self._node_details, self._shape_str(form=repr)))) cache[self] = node = RegularNode(label, args, {}, (type(self).__name__, times[self]), subgraph) return node @@ -2855,13 +2856,13 @@ def _derivative(self, var, seen): return zeros(self.shape+var.shape) def __str__(self): - return '{} {!r} <{}>'.format(self.__class__.__name__, self._name, ','.join(map(str, self.shape))) + return '{} {!r} <{}>'.format(self.__class__.__name__, self._name, self._shape_str(form=str)) def _node(self, cache, subgraph, times): if self in cache: return cache[self] else: - label = '\n'.join(filter(None, (type(self).__name__, self._name, ','.join(map(repr, self._axes))))) + label = '\n'.join(filter(None, (type(self).__name__, self._name, self._shape_str(form=repr)))) cache[self] = node = DuplicatedLeafNode(label, (type(self).__name__, times[self])) return node diff --git a/nutils/testing.py b/nutils/testing.py index f030057e8..96e64a59f 100644 --- a/nutils/testing.py +++ b/nutils/testing.py @@ -162,6 +162,8 @@ def setUp(self): to ignore warnings for the entire class. ''' + maxDiff = None # prevent assertEqual from shortening the diff error message + def enter_context(self, ctx): retval = ctx.__enter__() self.addCleanup(ctx.__exit__, None, None, None) diff --git a/tests/test_evaluable.py b/tests/test_evaluable.py index c4da0531f..e09e5a4eb 100644 --- a/tests/test_evaluable.py +++ b/tests/test_evaluable.py @@ -57,6 +57,10 @@ def assertFunctionAlmostEqual(self, actual, desired, decimal): numpy.add.at(dense, indices, values) self.assertArrayAlmostEqual(dense, desired, decimal) + def test_str(self): + a = evaluable.Array((), shape=(2, 3), dtype=float) + self.assertEqual(str(a),'nutils.evaluable.Array') + def test_evalconst(self): constargs = [numpy.random.uniform(size=shape) for shape in self.shapes] self.assertFunctionAlmostEqual(decimal=15, @@ -681,17 +685,17 @@ class asciitree(TestCase): def test_asciitree(self): f = evaluable.Sin(evaluable.Inflate(1, evaluable.Zeros((), int), 2)**evaluable.Diagonalize(evaluable.Argument('arg', (2,)))) self.assertEqual(f.asciitree(richoutput=True), - '%0 = Sin; a2,a2\n' - '└ %1 = Power; a2,a2\n' - ' ├ %2 = Transpose; 1,0; i2,s2\n' - ' │ └ %3 = InsertAxis; s2,i2\n' - ' │ ├ %4 = Inflate; s2\n' + '%0 = Sin; f:a2,a2\n' + '└ %1 = Power; f:a2,a2\n' + ' ├ %2 = Transpose; 1,0; i:i2,s2\n' + ' │ └ %3 = InsertAxis; i:s2,i2\n' + ' │ ├ %4 = Inflate; i:s2\n' ' │ │ ├ 1\n' ' │ │ ├ 0\n' ' │ │ └ 2\n' ' │ └ 2\n' - ' └ %5 = Diagonalize; d2,d2\n' - ' └ Argument; arg; a2\n') + ' └ %5 = Diagonalize; f:d2,d2\n' + ' └ Argument; arg; f:a2\n') @unittest.skipIf(sys.version_info < (3, 6), 'test requires dicts maintaining insertion order') def test_loop_sum(self): @@ -717,19 +721,19 @@ def test_loop_concatenate(self): 'NODES\n' '%B0 = LoopConcatenate\n' '├ shape[0] = 2\n' - '├ start = %B1 = Take\n' - '│ ├ %A2 = _SizesToOffsets; a3\n' - '│ │ └ %A3 = InsertAxis; i2\n' + '├ start = %B1 = Take; i:\n' + '│ ├ %A2 = _SizesToOffsets; i:a3\n' + '│ │ └ %A3 = InsertAxis; i:i2\n' '│ │ ├ 1\n' '│ │ └ 2\n' '│ └ %B4 = LoopIndex\n' '│ └ length = 2\n' - '├ stop = %B5 = Take\n' + '├ stop = %B5 = Take; i:\n' '│ ├ %A2\n' - '│ └ %B6 = Add\n' + '│ └ %B6 = Add; i:\n' '│ ├ %B4\n' '│ └ 1\n' - '└ func = %B7 = InsertAxis; i1\n' + '└ func = %B7 = InsertAxis; i:i1\n' ' ├ %B4\n' ' └ 1\n') @@ -744,19 +748,19 @@ def test_loop_concatenatecombined(self): 'NODES\n' '%B0 = LoopConcatenate\n' '├ shape[0] = 2\n' - '├ start = %B1 = Take\n' - '│ ├ %A2 = _SizesToOffsets; a3\n' - '│ │ └ %A3 = InsertAxis; i2\n' + '├ start = %B1 = Take; i:\n' + '│ ├ %A2 = _SizesToOffsets; i:a3\n' + '│ │ └ %A3 = InsertAxis; i:i2\n' '│ │ ├ 1\n' '│ │ └ 2\n' '│ └ %B4 = LoopIndex\n' '│ └ length = 2\n' - '├ stop = %B5 = Take\n' + '├ stop = %B5 = Take; i:\n' '│ ├ %A2\n' - '│ └ %B6 = Add\n' + '│ └ %B6 = Add; i:\n' '│ ├ %B4\n' '│ └ 1\n' - '└ func = %B7 = InsertAxis; i1\n' + '└ func = %B7 = InsertAxis; i:i1\n' ' ├ %B4\n' ' └ 1\n')