diff --git a/test/test_array.py b/test/test_array.py index f6a3b209..15af9e6d 100644 --- a/test/test_array.py +++ b/test/test_array.py @@ -28,6 +28,7 @@ from arraycontext import ( dataclass_array_container, + flatten, pytest_generate_tests_for_array_contexts, with_container_arithmetic, ) @@ -142,33 +143,31 @@ def _get_test_containers(actx, ambient_dim=2): @pytest.mark.parametrize("ord", [2, np.inf]) def test_container_norm(actx_factory, ord): actx = actx_factory() - c_test = _get_test_containers(actx) - # {{{ actx.np.linalg.norm + # {{{ flat_norm from pytools.obj_array import make_obj_array c = MyContainer(name="hey", mass=1, momentum=make_obj_array([2, 3]), enthalpy=5) c_obj_ary = make_obj_array([c, c]) - n1 = actx.np.linalg.norm(c_obj_ary, ord) + n1 = flat_norm(c_obj_ary, ord) n2 = np.linalg.norm([1, 2, 3, 5]*2, ord) assert abs(n1 - n2) < 1e-12 - # }}} - - # {{{ flat_norm - - # check nested vs actx.np.linalg.norm - assert actx.to_numpy(abs( - flat_norm(c_test[1], ord=ord) - - actx.np.linalg.norm(c_test[1], ord=ord))) < 1e-12 - # check nested container with only Numbers (and no actx) assert abs(flat_norm(c_obj_ary, ord=ord) - n2) < 1.0e-12 assert abs( flat_norm(np.array([1, 1], dtype=object), ord=ord) - np.linalg.norm([1, 1], ord=ord)) < 1.0e-12 + + # check nested + n1 = actx.to_numpy(flat_norm(c_test[1], ord=ord)) + n2 = np.linalg.norm([ + np.linalg.norm(actx.to_numpy(flatten(ary, actx)), ord=ord) for ary in c_test[1] + ], ord=ord) + assert abs(n1 - n2) < 1e-12 + # }}} # }}}