From b1be58a5a463aa03d4b87a7034899c5fdbb3b94c Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 31 Jul 2023 20:59:55 -0500 Subject: [PATCH 01/66] Update _reference_derivative_matrices to recognize TensorProductElementGroup --- grudge/op.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index f5781f4be..09e17884b 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -259,16 +259,35 @@ def _reference_derivative_matrices(actx: ArrayContext, # _reference_stiffness_transpose_matrices. assert out_element_group is in_element_group + from meshmode.mesh import TensorProductElementGroup + @keyed_memoize_in( actx, _reference_derivative_matrices, lambda grp: grp.discretization_key()) def get_ref_derivative_mats(grp): - from meshmode.discretization.poly_element import diff_matrices - return actx.freeze( - actx.tag_axis( - 1, DiscretizationDOFAxisTag(), - actx.from_numpy( - np.asarray(diff_matrices(grp))))) + + if isinstance(grp, TensorProductElementGroup): + import modepy as mp + import numpy.linalg as la + + space1d = grp.space.bases[0] + shape1d = grp.shape.bases[0] + + nodes1d = mp.edge_clustered_nodes_for_space(space1d, shape1d) + basis1d = mp.basis_for_space(space1d, shape1d) + + vdm1d = mp.vandermonde(basis1d.functions, nodes1d) + vdm_p1d = mp.vandermonde(basis1d.gradients, nodes1d)[0] + + return actx.freeze(actx.from_numpy(vdm_p1d @ la.inv(vdm1d))) + + else: + from meshmode.discretization.poly_element import diff_matrices + return actx.freeze( + actx.tag_axis( + 1, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray(diff_matrices(grp))))) return get_ref_derivative_mats(out_element_group) From 8fa3321ff01bea49ba4f81b159a580d7e77e7956 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 31 Jul 2023 21:10:08 -0500 Subject: [PATCH 02/66] Stub in tensor product gradient computation in _gradient_kernel --- grudge/op.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 09e17884b..cfd40eed5 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -202,19 +202,33 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, *, metric_in_matvec): # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. + from meshmode.mesh import TensorProductElementGroup + + def compute_tensor_product_grad(actx, diff_mat, vec): + """Exploits tensor product structure to differentiate each coordinate + axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) + """ + pass + per_group_grads = [ + + compute_tensor_product_grad(actx, get_diff_mat, vec_i) + if isinstance(in_grp, TensorProductElementGroup) + # r for rst axis # x for xyz axis - actx.einsum("xrej,rij,ej->xei" if metric_in_matvec else "xrei,rij,ej->xei", - ijm_i, - get_diff_mat( - actx, - out_element_group=out_grp, - in_element_group=in_grp - ), - vec_i, - arg_names=("inv_jac_t", "ref_stiffT_mat", "vec"), - tagged=(FirstAxisIsElementsTag(),)) + else actx.einsum( + "xrej,rij,ej->xei" if metric_in_matvec else "xrei,rij,ej->xei", + ijm_i, + get_diff_mat( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ), + vec_i, + arg_names=("inv_jac_t", "ref_stiffT_mat", "vec"), + tagged=(FirstAxisIsElementsTag(),)) + for out_grp, in_grp, vec_i, ijm_i in zip( out_discr.groups, in_discr.groups, vec, inv_jac_mat)] From 3778edfcfc872e058abf8f8655367b510e425632 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 31 Jul 2023 21:59:06 -0500 Subject: [PATCH 03/66] First version of grad routine --- grudge/op.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index cfd40eed5..3a3cd9919 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -204,15 +204,70 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, # (both strong and weak derivative) and their differences. from meshmode.mesh import TensorProductElementGroup - def compute_tensor_product_grad(actx, diff_mat, vec): + def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - pass + + from modepy.tools import ( + reshape_array_for_tensor_product_space, + unreshape_array_for_tensor_product_space) + + # reshape u to expose tensor product structure + vec = make_obj_array([ + reshape_array_for_tensor_product_space(grp.space, vec[i]) + for i in range(vec.shape[0]) + ]) + + # apply differentiation matrix to vec + if vec.shape[0] == 2: + specs = ["il,elj->eij", + "jl,eil->eij"] + elif vec.shape[1] == 3: + specs = ["il,eljk->eijk", + "jl,eilk->eijk", + "kl,eijl->eijk"] + else: + specs = None + assert specs is not None + + grad = make_obj_array([ + make_obj_array([ + actx.einsum( + spec, + diff_mat, + vec[i], + arg_names=("diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for i in range(vec.shape[0]) + ]) + for spec in specs + ]) + + # unreshape grad to apply geometric factors + # NOTE: In a future version, do not reshape before application of + # geometric factors. Can possibly "chain" the einsum as it is below + grad = make_obj_array([ + unreshape_array_for_tensor_product_space(grp.space, grad[i][0]) + for i in range(grad.shape[0]) + ]) + + # apply geometric factors to current grad + grad = make_obj_array([ + actx.einsum( + "rei,ei->ei", + ijm[i], + grad[i], + tagged=(FirstAxisIsElementsTag(),)) + for i in range(grad.shape[0]) + ]) + + return grad per_group_grads = [ - compute_tensor_product_grad(actx, get_diff_mat, vec_i) + compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i) if isinstance(in_grp, TensorProductElementGroup) # r for rst axis From 45e859e5a788a68dbba2fbadaadd53edc231b523 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 1 Aug 2023 12:39:48 -0500 Subject: [PATCH 04/66] Initial working version of tensor product gradient operator application --- grudge/op.py | 73 +++++++++++++++++++++++++++++++++++++------------ test/test_op.py | 68 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 17 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 3a3cd9919..c6e6ce36e 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -166,6 +166,36 @@ ) +# {{{ Temporary tools for tensor product operators +from pytools.tag import Tag +class OutputIsTensorProductDOFArrayOrdered(Tag): + pass + + +from grudge.array_context import PyOpenCLArrayContext +class TensorProductArrayContext(PyOpenCLArrayContext): + def transform_loopy_program(self, t_unit): + if len(t_unit.callables_table) == 1: + knl = t_unit.default_entrypoint + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + t_unit = t_unit.with_kernel(knl) + + return super().transform_loopy_program(t_unit) +# }}} + + # {{{ common derivative "kernels" def _single_axis_derivative_kernel( @@ -202,28 +232,30 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, *, metric_in_matvec): # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. - from meshmode.mesh import TensorProductElementGroup def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ + actx_tp = TensorProductArrayContext( + actx.queue, + allocator=actx.allocator, + force_device_scalars=actx._force_device_scalars) + from modepy.tools import ( reshape_array_for_tensor_product_space, unreshape_array_for_tensor_product_space) # reshape u to expose tensor product structure - vec = make_obj_array([ - reshape_array_for_tensor_product_space(grp.space, vec[i]) - for i in range(vec.shape[0]) - ]) + vec = reshape_array_for_tensor_product_space(grp.space, vec) # apply differentiation matrix to vec - if vec.shape[0] == 2: + # check len(vec.shape) since shape is expected to be (nelements, ndofs) + if len(vec.shape) == 3: specs = ["il,elj->eij", "jl,eil->eij"] - elif vec.shape[1] == 3: + elif len(vec.shape) == 4: specs = ["il,eljk->eijk", "jl,eilk->eijk", "kl,eijl->eijk"] @@ -231,31 +263,34 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): specs = None assert specs is not None + diff_mat = get_diff_mat(actx, grp, grp) grad = make_obj_array([ - make_obj_array([ - actx.einsum( + actx_tp.einsum( spec, diff_mat, - vec[i], + vec, arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - for i in range(vec.shape[0]) - ]) for spec in specs ]) # unreshape grad to apply geometric factors # NOTE: In a future version, do not reshape before application of - # geometric factors. Can possibly "chain" the einsum as it is below + # geometric factors. Can possibly "chain" the einsum. For example, the + # simplicial case below has einsum with spec + # ("xrei,rij,ei->ei") + # for the strong local gradient case grad = make_obj_array([ - unreshape_array_for_tensor_product_space(grp.space, grad[i][0]) + unreshape_array_for_tensor_product_space(grp.space, grad[i]) for i in range(grad.shape[0]) ]) # apply geometric factors to current grad + # FIXME: using einsum spec ("xrei,xei->xei") throws error: + # "Loopy does not directly support object arrays" grad = make_obj_array([ - actx.einsum( + actx_tp.einsum( "rei,ei->ei", ijm[i], grad[i], @@ -265,10 +300,12 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): return grad + from meshmode.discretization.poly_element import \ + TensorProductElementGroupBase per_group_grads = [ compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i) - if isinstance(in_grp, TensorProductElementGroup) + if isinstance(in_grp, TensorProductElementGroupBase) # r for rst axis # x for xyz axis @@ -335,7 +372,9 @@ def _reference_derivative_matrices(actx: ArrayContext, lambda grp: grp.discretization_key()) def get_ref_derivative_mats(grp): - if isinstance(grp, TensorProductElementGroup): + from meshmode.discretization.poly_element import \ + TensorProductElementGroupBase + if isinstance(grp, TensorProductElementGroupBase): import modepy as mp import numpy.linalg as la diff --git a/test/test_op.py b/test/test_op.py index fa7ee0bbd..8b20b2bb8 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -32,6 +32,7 @@ import pytest +from grudge.discretization import make_discretization_collection from grudge.array_context import PytestPyOpenCLArrayContextFactory from arraycontext import pytest_generate_tests_for_array_contexts pytest_generate_tests = pytest_generate_tests_for_array_contexts( @@ -159,6 +160,73 @@ def get_flux(u_tpair): assert (eoc_rec.order_estimate() >= order - 0.5 or eoc_rec.max_error() < 1e-11) + +@pytest.mark.parametrize("form", ["strong"]) +@pytest.mark.parametrize("dim", [2]) +@pytest.mark.parametrize("order", [2]) +@pytest.mark.parametrize(("vectorize", "nested"), [ + (False, False) + ]) +def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, + nested, visualize=False): + + actx = actx_factory() + from pytools.convergence import EOCRecorder + eoc_rec = EOCRecorder() + + from meshmode.mesh import TensorProductElementGroup + from meshmode.discretization.poly_element import \ + LegendreGaussLobattoTensorProductGroupFactory as LGL + for n in [4, 6, 8]: + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, + b=(1,)*dim, + nelements_per_axis=(n,)*dim, + group_cls=TensorProductElementGroup) + + import grudge.dof_desc as dd + dcoll = make_discretization_collection( + actx, + mesh, + discr_tag_to_group_factory={ + dd.DISCR_TAG_BASE: LGL(order)}) + + + def f(x): + ret = actx.np.cos(np.pi*x[0]) + actx.np.sin(np.pi*x[1]) + + if dim == 3: + ret = ret + actx.np.sin(np.pi*x[2]) + + return ret + + + def grad_f(x): + ret = make_obj_array([dcoll.zeros(actx) for _ in range(dim)]) + + ret[0] = -np.pi*actx.np.sin(np.pi*x[0]) + ret[1] = np.pi*actx.np.cos(np.pi*x[1]) + + if dim == 3: + ret[2] = np.pi*actx.np.cos(np.pi*x[2]) + + return ret + + + x = actx.thaw(dcoll.nodes()) + u = f(x) + ref_grad = grad_f(x) + grad = op.local_grad(dcoll, u) + + rel_linf_error = actx.to_numpy(op.norm(dcoll, ref_grad - grad, np.inf) / + op.norm(dcoll, ref_grad, np.inf)) + eoc_rec.add_data_point(1./n, rel_linf_error) + + print("L^inf error:") + print(eoc_rec) + assert (eoc_rec.order_estimate() >= order - 0.5 or + eoc_rec.max_error() < 1e-11) + # }}} From 22ffacdf969f15aa4f3018d1f287c02e1455adf5 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 1 Aug 2023 12:48:08 -0500 Subject: [PATCH 05/66] Add 3 dimensional test and order 3 test for 2D and 3D --- test/test_op.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/test/test_op.py b/test/test_op.py index 8b20b2bb8..1e45a4556 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -162,14 +162,16 @@ def get_flux(u_tpair): @pytest.mark.parametrize("form", ["strong"]) -@pytest.mark.parametrize("dim", [2]) -@pytest.mark.parametrize("order", [2]) +@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ (False, False) ]) def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, nested, visualize=False): - + """A "one-dimensional tensor product element" does not make sense, so the + one-dimensional case is excluded from this test. + """ actx = actx_factory() from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() @@ -193,10 +195,14 @@ def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, def f(x): - ret = actx.np.cos(np.pi*x[0]) + actx.np.sin(np.pi*x[1]) - - if dim == 3: - ret = ret + actx.np.sin(np.pi*x[2]) + if dim == 2: + ret = actx.np.cos(np.pi*x[0]) + actx.np.sin(np.pi*x[1]) + elif dim == 3: + ret = actx.np.cos(np.pi*x[0]) + actx.np.sin(np.pi*x[1]) \ + + actx.np.sin(np.pi*x[2]) + else: + ret = None + assert ret is not None return ret @@ -204,10 +210,12 @@ def f(x): def grad_f(x): ret = make_obj_array([dcoll.zeros(actx) for _ in range(dim)]) - ret[0] = -np.pi*actx.np.sin(np.pi*x[0]) - ret[1] = np.pi*actx.np.cos(np.pi*x[1]) - - if dim == 3: + if dim == 2: + ret[0] = -np.pi*actx.np.sin(np.pi*x[0]) + ret[1] = np.pi*actx.np.cos(np.pi*x[1]) + elif dim == 3: + ret[0] = -np.pi*actx.np.sin(np.pi*x[0]) + ret[1] = np.pi*actx.np.cos(np.pi*x[1]) ret[2] = np.pi*actx.np.cos(np.pi*x[2]) return ret From d0bd17e6bafce8bc98888bb9309bf4089ae295a0 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 1 Aug 2023 13:15:07 -0500 Subject: [PATCH 06/66] Add arg names to geometric factor application, refine some comments --- grudge/op.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index c6e6ce36e..c2f73d0e9 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -167,6 +167,8 @@ # {{{ Temporary tools for tensor product operators +# NOTE: Will possibly be removed in a future version of tensor product operator +# development since (I think) it is not entirely necessary from pytools.tag import Tag class OutputIsTensorProductDOFArrayOrdered(Tag): pass @@ -251,7 +253,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): vec = reshape_array_for_tensor_product_space(grp.space, vec) # apply differentiation matrix to vec - # check len(vec.shape) since shape is expected to be (nelements, ndofs) + # check len(vec.shape) since shape is expected to be + # (nelements, nnodes1d, nnodes1d) if len(vec.shape) == 3: specs = ["il,elj->eij", "jl,eil->eij"] @@ -294,7 +297,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): "rei,ei->ei", ijm[i], grad[i], - tagged=(FirstAxisIsElementsTag(),)) + tagged=(FirstAxisIsElementsTag(),)), + arg_names=("inv_jac_t", "vec") for i in range(grad.shape[0]) ]) From ef667bbcd6bc77b83cdeb85eb93952b21bfa3c95 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Wed, 2 Aug 2023 00:42:50 -0500 Subject: [PATCH 07/66] Divergence operator version 0.0 --- grudge/op.py | 96 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 84 insertions(+), 12 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index c2f73d0e9..449710ed8 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -297,8 +297,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): "rei,ei->ei", ijm[i], grad[i], - tagged=(FirstAxisIsElementsTag(),)), - arg_names=("inv_jac_t", "vec") + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", "vec")) for i in range(grad.shape[0]) ]) @@ -339,19 +339,91 @@ def _divergence_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec *, metric_in_matvec): # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. + + + def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): + """Exploits tensor product structure to differentiate each coordinate + axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) + """ + + actx_tp = TensorProductArrayContext( + actx.queue, + allocator=actx.allocator, + force_device_scalars=actx._force_device_scalars) + + from modepy.tools import ( + reshape_array_for_tensor_product_space, + unreshape_array_for_tensor_product_space) + + # reshape u to expose tensor product structure + vec = reshape_array_for_tensor_product_space(grp.space, vec) + + # define specs to extract dr, ds, dt + if len(vec.shape) == 3: + specs = ["il,elj->eij", + "jl,eil->eij"] + elif len(vec.shape) == 4: + specs = ["il,eljk->eijk", + "jl,eilk->eijk", + "kl,eijl->eijk"] + else: + specs = None + assert specs is not None + + diff_mat = get_diff_mat(actx, grp, grp) + drdsdt = make_obj_array([ + actx_tp.einsum( + spec, + diff_mat, + vec, + arg_names=("diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for spec in specs + ]) + + pu.db + if len(vec) == 3: + div = drdsdt[0] + drdsdt[1] + elif len(vec) == 4: + div = drdsdt[0] + drdsdt[1] + drdsdt[2] + else: + div = None + assert div is not None + + # see compute_tensor_product_grad for note on reshape before applying + # geometric factors + div = unreshape_array_for_tensor_product_space(grp.space, div) + + div = actx.einsum("xrei,ej->ej", + ijm, + div, + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", "vec")) + + return div + + + from meshmode.discretization.poly_element import \ + TensorProductElementGroupBase per_group_divs = [ + + compute_tensor_product_div(actx, in_grp, get_diff_mat, vec_i, ijm_i) + if isinstance(in_grp, TensorProductElementGroupBase) # r for rst axis # x for xyz axis - actx.einsum("xrej,rij,xej->ei" if metric_in_matvec else "xrei,rij,xej->ei", - ijm_i, - get_diff_mat( - actx, - out_element_group=out_grp, - in_element_group=in_grp - ), - vec_i, - arg_names=("inv_jac_t", "ref_stiffT_mat", "vec"), - tagged=(FirstAxisIsElementsTag(),)) + else actx.einsum( + "xrej,rij,xej->ei" if metric_in_matvec else "xrei,rij,xej->ei", + ijm_i, + get_diff_mat( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ), + vec_i, + arg_names=("inv_jac_t", "ref_stiffT_mat", "vec"), + tagged=(FirstAxisIsElementsTag(),)) + for out_grp, in_grp, vec_i, ijm_i in zip( out_discr.groups, in_discr.groups, vec, inv_jac_mat)] From eed2516ee1c504768b78767c71728f5badde11ce Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 4 Aug 2023 15:26:33 -0500 Subject: [PATCH 08/66] Prototype of divergence kernel. Needs work, but it passes currently included convergence tests --- grudge/op.py | 106 ++++++++++++++++++++++++++---------------------- test/test_op.py | 78 +++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 48 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 449710ed8..ee5e536bd 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -42,6 +42,7 @@ """ from __future__ import annotations +from re import I __copyright__ = """ Copyright (C) 2021 Andreas Kloeckner @@ -79,6 +80,7 @@ DiscretizationDOFAxisTag, DiscretizationElementAxisTag, DiscretizationFaceAxisTag) +from meshmode.discretization.poly_element import TensorProductElementGroupBase from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc @@ -235,6 +237,7 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. + def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) @@ -263,8 +266,9 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): "jl,eilk->eijk", "kl,eijl->eijk"] else: - specs = None - assert specs is not None + raise Exception("found dimension = {len(vec.shape)-1}. Special-case" + " tensor product operations are only valid for " + " 2 <= dimension <= 3.") diff_mat = get_diff_mat(actx, grp, grp) grad = make_obj_array([ @@ -274,7 +278,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): vec, arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + OutputIsTensorProductDOFArrayOrdered())) for spec in specs ]) @@ -304,8 +308,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): return grad - from meshmode.discretization.poly_element import \ - TensorProductElementGroupBase + per_group_grads = [ compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i) @@ -341,7 +344,7 @@ def _divergence_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec # (both strong and weak derivative) and their differences. - def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): + def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ @@ -358,58 +361,73 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): # reshape u to expose tensor product structure vec = reshape_array_for_tensor_product_space(grp.space, vec) - # define specs to extract dr, ds, dt - if len(vec.shape) == 3: - specs = ["il,elj->eij", - "jl,eil->eij"] - elif len(vec.shape) == 4: - specs = ["il,eljk->eijk", - "jl,eilk->eijk", - "kl,eijl->eijk"] + # apply differentiation matrix to vec + # check len(vec.shape) since shape is expected to be + # (nelements, nnodes1d, nnodes1d) + # FIXME: make this "dimension independent" + if len(vec.shape) == 4: + specs = ["il,xelj->eij", + "jl,xeil->eij"] + elif len(vec.shape) == 5: + specs = ["il,xeljk->eijk", + "jl,xeilk->eijk", + "kl,xeijl->eijk"] else: - specs = None - assert specs is not None + raise Exception("found dimension = {len(vec.shape)-2}. Special-case" + " tensor product operations are only valid for " + " 2 <= dimension <= 3.") diff_mat = get_diff_mat(actx, grp, grp) - drdsdt = make_obj_array([ - actx_tp.einsum( + + # get partial derivatives for each ref. coord. axis + partials = make_obj_array([ + actx_tp.einsum( spec, diff_mat, vec, arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for spec in specs - ]) - - pu.db - if len(vec) == 3: - div = drdsdt[0] + drdsdt[1] - elif len(vec) == 4: - div = drdsdt[0] + drdsdt[1] + drdsdt[2] - else: - div = None - assert div is not None + OutputIsTensorProductDOFArrayOrdered())) + for spec in specs + ]) - # see compute_tensor_product_grad for note on reshape before applying - # geometric factors - div = unreshape_array_for_tensor_product_space(grp.space, div) + # unreshape partials to apply geometric factors + # NOTE: In a future version, do not reshape before application of + # geometric factors. Can possibly "chain" the einsum. For example, the + # simplicial case below has einsum with spec + # ("xrei,rij,xej->ei") + # for the strong local divergence case + partials = make_obj_array([ + unreshape_array_for_tensor_product_space(grp.space, partials[i]) + for i in range(partials.shape[0]) + ]) - div = actx.einsum("xrei,ej->ej", - ijm, - div, + # apply geometric factors to partial derivatives + # FIXME: using einsum spec ("xrei,xei->xei") throws error: + # "Loopy does not directly support object arrays" + partials = make_obj_array([ + actx_tp.einsum( + "rei,ei->ei", + ijm[i], + partials[i], tagged=(FirstAxisIsElementsTag(),), arg_names=("inv_jac_t", "vec")) + for i in range(partials.shape[0]) + ]) + + if partials.shape[0] == 2: + div = partials[0] + partials[1] + else: + div = partials[0] + partials[1] + partials[2] return div - from meshmode.discretization.poly_element import \ - TensorProductElementGroupBase per_group_divs = [ compute_tensor_product_div(actx, in_grp, get_diff_mat, vec_i, ijm_i) if isinstance(in_grp, TensorProductElementGroupBase) + # r for rst axis # x for xyz axis else actx.einsum( @@ -441,24 +459,16 @@ def _reference_derivative_matrices(actx: ArrayContext, # _reference_stiffness_transpose_matrices. assert out_element_group is in_element_group - from meshmode.mesh import TensorProductElementGroup - @keyed_memoize_in( actx, _reference_derivative_matrices, lambda grp: grp.discretization_key()) def get_ref_derivative_mats(grp): - - from meshmode.discretization.poly_element import \ - TensorProductElementGroupBase if isinstance(grp, TensorProductElementGroupBase): import modepy as mp import numpy.linalg as la - space1d = grp.space.bases[0] - shape1d = grp.shape.bases[0] - - nodes1d = mp.edge_clustered_nodes_for_space(space1d, shape1d) - basis1d = mp.basis_for_space(space1d, shape1d) + nodes1d = grp.unit_nodes_1d + basis1d = grp.basis_1d_obj() vdm1d = mp.vandermonde(basis1d.functions, nodes1d) vdm_p1d = mp.vandermonde(basis1d.gradients, nodes1d)[0] diff --git a/test/test_op.py b/test/test_op.py index 1e45a4556..90860faae 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -351,6 +351,84 @@ def get_flux(u_tpair): assert (eoc_rec.order_estimate() >= order - 0.5 or eoc_rec.max_error() < 1e-11) + +@pytest.mark.parametrize("form", ["strong"]) +@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize("order", [2, 3]) +@pytest.mark.parametrize(("vectorize", "nested"), [ + (False, False) + ]) +def test_tensor_product_divergence(actx_factory, form, dim, order, vectorize, + nested, visualize=False): + """A "one-dimensional tensor product element" does not make sense, so the + one-dimensional case is excluded from this test. + """ + actx = actx_factory() + from pytools.convergence import EOCRecorder + eoc_rec = EOCRecorder() + + from meshmode.mesh import TensorProductElementGroup + from meshmode.discretization.poly_element import \ + LegendreGaussLobattoTensorProductGroupFactory as LGL + for n in [4, 6, 8]: + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, + b=(1,)*dim, + nelements_per_axis=(n,)*dim, + group_cls=TensorProductElementGroup) + + import grudge.dof_desc as dd + dcoll = make_discretization_collection( + actx, + mesh, + discr_tag_to_group_factory={ + dd.DISCR_TAG_BASE: LGL(order)}) + + + def f(x): + if dim == 2: + ret = make_obj_array([dcoll.empty(actx) for _ in range(dim)]) + ret[0] = actx.np.cos(np.pi*x[0]) + ret[1] = actx.np.sin(np.pi*x[1]) + + return ret + elif dim == 3: + ret = make_obj_array([dcoll.empty(actx) for _ in range(dim)]) + ret[0] = actx.np.cos(np.pi*x[0]) + ret[1] = actx.np.sin(np.pi*x[1]) + ret[2] = actx.np.sin(np.pi*x[2]) + + return ret + + + def div_f(x): + + if dim == 2: + ret = -np.pi*actx.np.sin(np.pi*x[0]) + \ + np.pi*actx.np.cos(np.pi*x[1]) + return ret + elif dim == 3: + ret = -np.pi*actx.np.sin(np.pi*x[0]) + \ + np.pi*actx.np.cos(np.pi*x[1]) + \ + np.pi*actx.np.cos(np.pi*x[2]) + + return ret + + + x = actx.thaw(dcoll.nodes()) + u = f(x) + ref_div = div_f(x) + div = op.local_div(dcoll, u) + + rel_linf_error = actx.to_numpy(op.norm(dcoll, ref_div - div, np.inf) / + op.norm(dcoll, ref_div, np.inf)) + eoc_rec.add_data_point(1./n, rel_linf_error) + + print("L^inf error:") + print(eoc_rec) + assert (eoc_rec.order_estimate() >= order - 0.5 or + eoc_rec.max_error() < 1e-11) + # }}} From 1e40a1199782d62543ba59f030382da10e5fe313 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 4 Aug 2023 15:28:19 -0500 Subject: [PATCH 09/66] Remove random import included by CoC autocomplete --- grudge/op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grudge/op.py b/grudge/op.py index ee5e536bd..39c99fd08 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -42,7 +42,6 @@ """ from __future__ import annotations -from re import I __copyright__ = """ Copyright (C) 2021 Andreas Kloeckner From f2b0275a3e7e6b242b691136f0c96e5c20b65c7a Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 5 Aug 2023 16:11:00 -0500 Subject: [PATCH 10/66] Generate einsum specification dynamically instead of using if-else --- grudge/op.py | 81 ++++++++++++++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 39c99fd08..3447487c7 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -254,31 +254,36 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): # reshape u to expose tensor product structure vec = reshape_array_for_tensor_product_space(grp.space, vec) - # apply differentiation matrix to vec - # check len(vec.shape) since shape is expected to be - # (nelements, nnodes1d, nnodes1d) - if len(vec.shape) == 3: - specs = ["il,elj->eij", - "jl,eil->eij"] - elif len(vec.shape) == 4: - specs = ["il,eljk->eijk", - "jl,eilk->eijk", - "kl,eijl->eijk"] - else: - raise Exception("found dimension = {len(vec.shape)-1}. Special-case" - " tensor product operations are only valid for " - " 2 <= dimension <= 3.") + # apply differentiation matrix to function data + def pre_dims(axis): + return "ijk"[0:axis] + + + def post_dims(axis): + return "ijk"[axis+1:grp.dim] + + + def out_dims(): + return "ijk"[:grp.dim] + + + def axis(i): + return "ijk"[i] + diff_mat = get_diff_mat(actx, grp, grp) + # einsum specs will look something like: + # "il,eljk->eijk" (3D first coordinate partial) + # "jl,eil->eij" (2D second coordinate partial) grad = make_obj_array([ actx_tp.einsum( - spec, + f"{axis(i)}l,e{pre_dims(i)}l{post_dims(i)}->e{out_dims()}", diff_mat, vec, arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - for spec in specs + for i in range(grp.dim) ]) # unreshape grad to apply geometric factors @@ -289,7 +294,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): # for the strong local gradient case grad = make_obj_array([ unreshape_array_for_tensor_product_space(grp.space, grad[i]) - for i in range(grad.shape[0]) + for i in range(grp.dim) ]) # apply geometric factors to current grad @@ -360,34 +365,36 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): # reshape u to expose tensor product structure vec = reshape_array_for_tensor_product_space(grp.space, vec) - # apply differentiation matrix to vec - # check len(vec.shape) since shape is expected to be - # (nelements, nnodes1d, nnodes1d) - # FIXME: make this "dimension independent" - if len(vec.shape) == 4: - specs = ["il,xelj->eij", - "jl,xeil->eij"] - elif len(vec.shape) == 5: - specs = ["il,xeljk->eijk", - "jl,xeilk->eijk", - "kl,xeijl->eijk"] - else: - raise Exception("found dimension = {len(vec.shape)-2}. Special-case" - " tensor product operations are only valid for " - " 2 <= dimension <= 3.") + # apply differentiation matrix to function data + def pre_dims(axis): + return "ijk"[0:axis] + + + def post_dims(axis): + return "ijk"[axis+1:grp.dim] + + + def out_dims(): + return "ijk"[:grp.dim] + + + def axis(i): + return "ijk"[i] - diff_mat = get_diff_mat(actx, grp, grp) # get partial derivatives for each ref. coord. axis + diff_mat = get_diff_mat(actx, grp, grp) + + # see comment on einsum spec in `compute_tensor_product_grad` partials = make_obj_array([ actx_tp.einsum( - spec, + f"{axis(i)}l,xe{pre_dims(i)}l{post_dims(i)}->e{out_dims()}", diff_mat, vec, arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - for spec in specs + for i in range(grp.dim) ]) # unreshape partials to apply geometric factors @@ -469,8 +476,8 @@ def get_ref_derivative_mats(grp): nodes1d = grp.unit_nodes_1d basis1d = grp.basis_1d_obj() - vdm1d = mp.vandermonde(basis1d.functions, nodes1d) - vdm_p1d = mp.vandermonde(basis1d.gradients, nodes1d)[0] + vdm_1d = mp.vandermonde(basis1d.functions, nodes1d) + vdm_p_1d = mp.vandermonde(basis1d.gradients, nodes1d)[0] return actx.freeze(actx.from_numpy(vdm_p1d @ la.inv(vdm1d))) From 036681c01e8732aff8da627d2fc841ba7cb10298 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 5 Aug 2023 16:16:15 -0500 Subject: [PATCH 11/66] Rename vandermonde and vandermonde derivative matrices --- grudge/op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grudge/op.py b/grudge/op.py index 3447487c7..1e1c378a1 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -479,7 +479,7 @@ def get_ref_derivative_mats(grp): vdm_1d = mp.vandermonde(basis1d.functions, nodes1d) vdm_p_1d = mp.vandermonde(basis1d.gradients, nodes1d)[0] - return actx.freeze(actx.from_numpy(vdm_p1d @ la.inv(vdm1d))) + return actx.freeze(actx.from_numpy(vdm_p_1d @ la.inv(vdm_1d))) else: from meshmode.discretization.poly_element import diff_matrices From 7ad9017a92be718649bda3884c940d0abde18be3 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 7 Aug 2023 10:47:21 -0500 Subject: [PATCH 12/66] Give einsums a single source of truth. Still only valid for dim <= 3 --- grudge/op.py | 50 +++++++------------------------------------------- 1 file changed, 7 insertions(+), 43 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 1e1c378a1..784e60667 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -254,36 +254,18 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): # reshape u to expose tensor product structure vec = reshape_array_for_tensor_product_space(grp.space, vec) - # apply differentiation matrix to function data - def pre_dims(axis): - return "ijk"[0:axis] - - - def post_dims(axis): - return "ijk"[axis+1:grp.dim] - - - def out_dims(): - return "ijk"[:grp.dim] - - - def axis(i): - return "ijk"[i] - - + # apply operators to function data + dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) - # einsum specs will look something like: - # "il,eljk->eijk" (3D first coordinate partial) - # "jl,eil->eij" (2D second coordinate partial) grad = make_obj_array([ actx_tp.einsum( - f"{axis(i)}l,e{pre_dims(i)}l{post_dims(i)}->e{out_dims()}", + f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", diff_mat, vec, arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - for i in range(grp.dim) + for i in range(dim) ]) # unreshape grad to apply geometric factors @@ -366,35 +348,17 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): vec = reshape_array_for_tensor_product_space(grp.space, vec) # apply differentiation matrix to function data - def pre_dims(axis): - return "ijk"[0:axis] - - - def post_dims(axis): - return "ijk"[axis+1:grp.dim] - - - def out_dims(): - return "ijk"[:grp.dim] - - - def axis(i): - return "ijk"[i] - - - # get partial derivatives for each ref. coord. axis + dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) - - # see comment on einsum spec in `compute_tensor_product_grad` partials = make_obj_array([ actx_tp.einsum( - f"{axis(i)}l,xe{pre_dims(i)}l{post_dims(i)}->e{out_dims()}", + f"ij,xe{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", diff_mat, vec, arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - for i in range(grp.dim) + for i in range(dim) ]) # unreshape partials to apply geometric factors From c645dbe43f1620d50b6e89cfb0f1ffd02f0c8654 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 10 Aug 2023 15:17:24 -0500 Subject: [PATCH 13/66] Move TP array context to array_context.py, other minor changes --- grudge/array_context.py | 39 ++++++++++++++++++++++++++++++++++++ grudge/op.py | 44 +++++++++-------------------------------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index 2e82519e2..f29e8ef5d 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -127,6 +127,45 @@ def __init__(self, queue: "pyopencl.CommandQueue", # }}} +# {{{ Tensor product array context + +class OutputIsTensorProductDOFArrayOrdered(Tag): + """Signify that the strides will not be of order "C" or "F". See + :class:`grudge.array_context.TensorProductArrayContext` for more details. + """ + pass + + +class TensorProductArrayContext(_PyOpenCLArrayContextBase): + """Specialized array context for use with tensor product elements. + + The strides for the arrays containing tensor product element data are of the + form (slow, fastest, faster, fast). These strides are not "C" or "F" order. + Hence, this specialized array context takes care of specifying the + particular strides required. + """ + + def transform_loopy_program(self, t_unit): + if len(t_unit.callables_table) == 1: + knl = t_unit.default_entrypoint + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + t_unit = t_unit.with_kernel(knl) + + return super().transform_loopy_program(t_unit) +# }}} + # {{{ pytato diff --git a/grudge/op.py b/grudge/op.py index 784e60667..08eeacbd2 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -83,6 +83,9 @@ from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc +from grudge.array_context import ( + TensorProductArrayContext, + OutputIsTensorProductDOFArrayOrdered) from pytools import keyed_memoize_in from pytools.obj_array import make_obj_array @@ -167,38 +170,6 @@ ) -# {{{ Temporary tools for tensor product operators -# NOTE: Will possibly be removed in a future version of tensor product operator -# development since (I think) it is not entirely necessary -from pytools.tag import Tag -class OutputIsTensorProductDOFArrayOrdered(Tag): - pass - - -from grudge.array_context import PyOpenCLArrayContext -class TensorProductArrayContext(PyOpenCLArrayContext): - def transform_loopy_program(self, t_unit): - if len(t_unit.callables_table) == 1: - knl = t_unit.default_entrypoint - if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): - new_args = [] - for arg in knl.args: - if arg.is_output: - arg = arg.copy(dim_tags=( - f"N{len(arg.shape)-1}," - + ",".join(f"N{i}" - for i in range(len(arg.shape)-1)) - )) - - new_args.append(arg) - - knl = knl.copy(args=new_args) - t_unit = t_unit.with_kernel(knl) - - return super().transform_loopy_program(t_unit) -# }}} - - # {{{ common derivative "kernels" def _single_axis_derivative_kernel( @@ -437,11 +408,14 @@ def get_ref_derivative_mats(grp): import modepy as mp import numpy.linalg as la + # not functional in current state nodes1d = grp.unit_nodes_1d - basis1d = grp.basis_1d_obj() + bases_1d = grp.bases_1d() - vdm_1d = mp.vandermonde(basis1d.functions, nodes1d) - vdm_p_1d = mp.vandermonde(basis1d.gradients, nodes1d)[0] + diff_mats = [] + for i in range(len(bases_1d)): + vdm_1d = mp.vandermonde(bases_1d.functions, nodes1d) + vdm_p_1d = mp.vandermonde(bases_1d.gradients, nodes1d)[0] return actx.freeze(actx.from_numpy(vdm_p_1d @ la.inv(vdm_1d))) From dcd7ca06b58bc4fd7d9deb2ebb738fc938eb565d Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 10 Aug 2023 18:01:40 -0500 Subject: [PATCH 14/66] Update tensor product grad test to match the other grad test case --- grudge/op.py | 7 +-- test/test_op.py | 124 ++++++++++++++++++++++++++++++++++-------------- 2 files changed, 91 insertions(+), 40 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 08eeacbd2..5a64cf624 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -408,14 +408,11 @@ def get_ref_derivative_mats(grp): import modepy as mp import numpy.linalg as la - # not functional in current state nodes1d = grp.unit_nodes_1d bases_1d = grp.bases_1d() - diff_mats = [] - for i in range(len(bases_1d)): - vdm_1d = mp.vandermonde(bases_1d.functions, nodes1d) - vdm_p_1d = mp.vandermonde(bases_1d.gradients, nodes1d)[0] + vdm_1d = mp.vandermonde(bases_1d.functions, nodes1d) + vdm_p_1d = mp.vandermonde(bases_1d.gradients, nodes1d)[0] return actx.freeze(actx.from_numpy(vdm_p_1d @ la.inv(vdm_1d))) diff --git a/test/test_op.py b/test/test_op.py index 90860faae..ec96209c4 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -165,75 +165,129 @@ def get_flux(u_tpair): @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ - (False, False) + (False, False), + (True, False), + (True, True) ]) -def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, +def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, nested, visualize=False): """A "one-dimensional tensor product element" does not make sense, so the one-dimensional case is excluded from this test. """ + actx = actx_factory() + from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() from meshmode.mesh import TensorProductElementGroup from meshmode.discretization.poly_element import \ - LegendreGaussLobattoTensorProductGroupFactory as LGL + LegendreGaussLobattoTensorProductGroupFactory as LGL for n in [4, 6, 8]: mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, - b=(1,)*dim, + a=(-1,)*dim, b=(1,)*dim, nelements_per_axis=(n,)*dim, group_cls=TensorProductElementGroup) import grudge.dof_desc as dd - dcoll = make_discretization_collection( + dcoll = DiscretizationCollection( actx, mesh, discr_tag_to_group_factory={ dd.DISCR_TAG_BASE: LGL(order)}) - def f(x): - if dim == 2: - ret = actx.np.cos(np.pi*x[0]) + actx.np.sin(np.pi*x[1]) - elif dim == 3: - ret = actx.np.cos(np.pi*x[0]) + actx.np.sin(np.pi*x[1]) \ - + actx.np.sin(np.pi*x[2]) - else: - ret = None - assert ret is not None + result = dcoll.zeros(actx) + 1 + for i in range(dim-1): + result = result * actx.np.sin(np.pi*x[i]) + result = result * actx.np.cos(np.pi/2*x[dim-1]) + return result - return ret + def grad_f(x): + result = make_obj_array([dcoll.zeros(actx) + 1 for _ in range(dim)]) + for i in range(dim-1): + for j in range(i): + result[i] = result[i] * actx.np.sin(np.pi*x[j]) + result[i] = result[i] * np.pi*actx.np.cos(np.pi*x[i]) + for j in range(i+1, dim-1): + result[i] = result[i] * actx.np.sin(np.pi*x[j]) + result[i] = result[i] * actx.np.cos(np.pi/2*x[dim-1]) + for j in range(dim-1): + result[dim-1] = result[dim-1] * actx.np.sin(np.pi*x[j]) + result[dim-1] = result[dim-1] * (-np.pi/2*actx.np.sin(np.pi/2*x[dim-1])) + return result + x = actx.thaw(dcoll.nodes()) - def grad_f(x): - ret = make_obj_array([dcoll.zeros(actx) for _ in range(dim)]) + if vectorize: + u = make_obj_array([(i+1)*f(x) for i in range(dim)]) + else: + u = f(x) - if dim == 2: - ret[0] = -np.pi*actx.np.sin(np.pi*x[0]) - ret[1] = np.pi*actx.np.cos(np.pi*x[1]) - elif dim == 3: - ret[0] = -np.pi*actx.np.sin(np.pi*x[0]) - ret[1] = np.pi*actx.np.cos(np.pi*x[1]) - ret[2] = np.pi*actx.np.cos(np.pi*x[2]) + def get_flux(u_tpair): + dd = u_tpair.dd + dd_allfaces = dd.with_dtag("all_faces") + normal = actx.thaw(dcoll.normal(dd)) + u_avg = u_tpair.avg + if vectorize: + if nested: + flux = make_obj_array([u_avg_i * normal for u_avg_i in u_avg]) + else: + flux = np.outer(u_avg, normal) + else: + flux = u_avg * normal + return op.project(dcoll, dd, dd_allfaces, flux) - return ret + dd_allfaces = DOFDesc("all_faces") + if form == "strong": + grad_u = ( + op.local_grad(dcoll, u, nested=nested) + # No flux terms because u doesn't have inter-el jumps + ) + elif form == "weak": + grad_u = op.inverse_mass(dcoll, + -op.weak_local_grad(dcoll, u, nested=nested) # pylint: disable=E1130 + + # noqa: W504 + op.face_mass(dcoll, + dd_allfaces, + # Note: no boundary flux terms here because u_ext == u_int == 0 + sum(get_flux(utpair) + for utpair in op.interior_trace_pairs(dcoll, u)) + ) + ) + else: + raise ValueError("Invalid form argument.") - x = actx.thaw(dcoll.nodes()) - u = f(x) - ref_grad = grad_f(x) - grad = op.local_grad(dcoll, u) + if vectorize: + expected_grad_u = make_obj_array( + [(i+1)*grad_f(x) for i in range(dim)]) + if not nested: + expected_grad_u = np.stack(expected_grad_u, axis=0) + else: + expected_grad_u = grad_f(x) - rel_linf_error = actx.to_numpy(op.norm(dcoll, ref_grad - grad, np.inf) / - op.norm(dcoll, ref_grad, np.inf)) - eoc_rec.add_data_point(1./n, rel_linf_error) + if visualize: + from grudge.shortcuts import make_visualizer + vis = make_visualizer(dcoll, vis_order=order if dim == 3 else dim+3) + + filename = (f"test_gradient_{form}_{dim}_{order}" + f"{'_vec' if vectorize else ''}{'_nested' if nested else ''}.vtu") + vis.write_vtk_file(filename, [ + ("u", u), + ("grad_u", grad_u), + ("expected_grad_u", expected_grad_u), + ], overwrite=True) + + rel_linf_err = actx.to_numpy( + op.norm(dcoll, grad_u - expected_grad_u, np.inf) + / op.norm(dcoll, expected_grad_u, np.inf)) + eoc_rec.add_data_point(1./n, rel_linf_err) print("L^inf error:") print(eoc_rec) - assert (eoc_rec.order_estimate() >= order - 0.5 or - eoc_rec.max_error() < 1e-11) + assert (eoc_rec.order_estimate() >= order - 0.5 + or eoc_rec.max_error() < 1e-11) # }}} From 683cdd839df7b48781eed59f1bdfc1fc6f0c5f3b Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 10 Aug 2023 18:14:22 -0500 Subject: [PATCH 15/66] Update tensor product divergence test to match original test case. --- test/test_op.py | 121 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 38 deletions(-) diff --git a/test/test_op.py b/test/test_op.py index ec96209c4..92309b3a2 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -169,7 +169,7 @@ def get_flux(u_tpair): (True, False), (True, True) ]) -def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, +def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, nested, visualize=False): """A "one-dimensional tensor product element" does not make sense, so the one-dimensional case is excluded from this test. @@ -410,14 +410,17 @@ def get_flux(u_tpair): @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ - (False, False) + (False, False), + (True, False), + (True, True) ]) def test_tensor_product_divergence(actx_factory, form, dim, order, vectorize, - nested, visualize=False): + nested, visualize=False): """A "one-dimensional tensor product element" does not make sense, so the one-dimensional case is excluded from this test. """ actx = actx_factory() + from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() @@ -426,63 +429,105 @@ def test_tensor_product_divergence(actx_factory, form, dim, order, vectorize, LegendreGaussLobattoTensorProductGroupFactory as LGL for n in [4, 6, 8]: mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, - b=(1,)*dim, + a=(-1,)*dim, b=(1,)*dim, nelements_per_axis=(n,)*dim, group_cls=TensorProductElementGroup) import grudge.dof_desc as dd - dcoll = make_discretization_collection( + dcoll = DiscretizationCollection( actx, mesh, discr_tag_to_group_factory={ dd.DISCR_TAG_BASE: LGL(order)}) - def f(x): - if dim == 2: - ret = make_obj_array([dcoll.empty(actx) for _ in range(dim)]) - ret[0] = actx.np.cos(np.pi*x[0]) - ret[1] = actx.np.sin(np.pi*x[1]) + result = make_obj_array([dcoll.zeros(actx) + (i+1) for i in range(dim)]) + for i in range(dim-1): + result = result * actx.np.sin(np.pi*x[i]) + result = result * actx.np.cos(np.pi/2*x[dim-1]) + return result - return ret - elif dim == 3: - ret = make_obj_array([dcoll.empty(actx) for _ in range(dim)]) - ret[0] = actx.np.cos(np.pi*x[0]) - ret[1] = actx.np.sin(np.pi*x[1]) - ret[2] = actx.np.sin(np.pi*x[2]) + def div_f(x): + result = dcoll.zeros(actx) + for i in range(dim-1): + deriv = dcoll.zeros(actx) + (i+1) + for j in range(i): + deriv = deriv * actx.np.sin(np.pi*x[j]) + deriv = deriv * np.pi*actx.np.cos(np.pi*x[i]) + for j in range(i+1, dim-1): + deriv = deriv * actx.np.sin(np.pi*x[j]) + deriv = deriv * actx.np.cos(np.pi/2*x[dim-1]) + result = result + deriv + deriv = dcoll.zeros(actx) + dim + for j in range(dim-1): + deriv = deriv * actx.np.sin(np.pi*x[j]) + deriv = deriv * (-np.pi/2*actx.np.sin(np.pi/2*x[dim-1])) + result = result + deriv + return result - return ret + x = actx.thaw(dcoll.nodes()) + if vectorize: + u = make_obj_array([(i+1)*f(x) for i in range(dim)]) + if not nested: + u = np.stack(u, axis=0) + else: + u = f(x) - def div_f(x): + def get_flux(u_tpair): + dd = u_tpair.dd + dd_allfaces = dd.with_dtag("all_faces") + normal = actx.thaw(dcoll.normal(dd)) + flux = u_tpair.avg @ normal + return op.project(dcoll, dd, dd_allfaces, flux) - if dim == 2: - ret = -np.pi*actx.np.sin(np.pi*x[0]) + \ - np.pi*actx.np.cos(np.pi*x[1]) - return ret - elif dim == 3: - ret = -np.pi*actx.np.sin(np.pi*x[0]) + \ - np.pi*actx.np.cos(np.pi*x[1]) + \ - np.pi*actx.np.cos(np.pi*x[2]) + dd_allfaces = DOFDesc("all_faces") - return ret + if form == "strong": + div_u = ( + op.local_div(dcoll, u) + # No flux terms because u doesn't have inter-el jumps + ) + elif form == "weak": + div_u = op.inverse_mass(dcoll, + -op.weak_local_div(dcoll, u) + + # noqa: W504 + op.face_mass(dcoll, + dd_allfaces, + # Note: no boundary flux terms here because u_ext == u_int == 0 + sum(get_flux(utpair) + for utpair in op.interior_trace_pairs(dcoll, u)) + ) + ) + else: + raise ValueError("Invalid form argument.") + if vectorize: + expected_div_u = make_obj_array([(i+1)*div_f(x) for i in range(dim)]) + else: + expected_div_u = div_f(x) - x = actx.thaw(dcoll.nodes()) - u = f(x) - ref_div = div_f(x) - div = op.local_div(dcoll, u) + if visualize: + from grudge.shortcuts import make_visualizer + vis = make_visualizer(dcoll, vis_order=order if dim == 3 else dim+3) - rel_linf_error = actx.to_numpy(op.norm(dcoll, ref_div - div, np.inf) / - op.norm(dcoll, ref_div, np.inf)) - eoc_rec.add_data_point(1./n, rel_linf_error) + filename = (f"test_divergence_{form}_{dim}_{order}" + f"{'_vec' if vectorize else ''}{'_nested' if nested else ''}.vtu") + vis.write_vtk_file(filename, [ + ("u", u), + ("div_u", div_u), + ("expected_div_u", expected_div_u), + ], overwrite=True) + + rel_linf_err = actx.to_numpy( + op.norm(dcoll, div_u - expected_div_u, np.inf) + / op.norm(dcoll, expected_div_u, np.inf)) + eoc_rec.add_data_point(1./n, rel_linf_err) print("L^inf error:") print(eoc_rec) - assert (eoc_rec.order_estimate() >= order - 0.5 or - eoc_rec.max_error() < 1e-11) - + assert (eoc_rec.order_estimate() >= order - 0.5 + or eoc_rec.max_error() < 1e-11) # }}} From 15639503e3e1e097bcfe938d5ded17605e8ade66 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 14 Aug 2023 13:05:09 -0500 Subject: [PATCH 16/66] Divergence kernel functioning again --- grudge/op.py | 50 ++++++++++++++++++------------------------------- test/test_op.py | 5 +++-- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 5a64cf624..9431ac901 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -250,18 +250,14 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): for i in range(grp.dim) ]) - # apply geometric factors to current grad - # FIXME: using einsum spec ("xrei,xei->xei") throws error: - # "Loopy does not directly support object arrays" - grad = make_obj_array([ - actx_tp.einsum( - "rei,ei->ei", - ijm[i], - grad[i], - tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", "vec")) - for i in range(grad.shape[0]) - ]) + # apply geometric factors + grad = actx.np.stack([grad[i] for i in range(dim)]) + grad = actx.einsum( + "xrei,xei->xei", + ijm, + grad, + arg_names=("inv_jac_t", "vec"), + tagged=(FirstAxisIsElementsTag(),)) return grad @@ -305,7 +301,6 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - actx_tp = TensorProductArrayContext( actx.queue, allocator=actx.allocator, @@ -323,9 +318,9 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): diff_mat = get_diff_mat(actx, grp, grp) partials = make_obj_array([ actx_tp.einsum( - f"ij,xe{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", + f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", diff_mat, - vec, + vec[i], arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -343,23 +338,14 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): for i in range(partials.shape[0]) ]) - # apply geometric factors to partial derivatives - # FIXME: using einsum spec ("xrei,xei->xei") throws error: - # "Loopy does not directly support object arrays" - partials = make_obj_array([ - actx_tp.einsum( - "rei,ei->ei", - ijm[i], - partials[i], - tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", "vec")) - for i in range(partials.shape[0]) - ]) - - if partials.shape[0] == 2: - div = partials[0] + partials[1] - else: - div = partials[0] + partials[1] + partials[2] + # apply geometric factors + partials = actx.np.stack([partials[i] for i in range(dim)]) + div = actx.einsum( + "xrei,xei->ei", + ijm, + partials, + arg_names=("inv_jac_t", "vec"), + tagged=(FirstAxisIsElementsTag(),)) return div diff --git a/test/test_op.py b/test/test_op.py index 92309b3a2..e836a63e1 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -429,12 +429,13 @@ def test_tensor_product_divergence(actx_factory, form, dim, order, vectorize, LegendreGaussLobattoTensorProductGroupFactory as LGL for n in [4, 6, 8]: mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, b=(1,)*dim, + a=(-1,)*dim, + b=(1,)*dim, nelements_per_axis=(n,)*dim, group_cls=TensorProductElementGroup) import grudge.dof_desc as dd - dcoll = DiscretizationCollection( + dcoll = make_discretization_collection( actx, mesh, discr_tag_to_group_factory={ From e1380fed4fb48d8a8cca4e44cb4aa5fbd72b6d54 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 20 Aug 2023 20:42:36 -0500 Subject: [PATCH 17/66] Update some comments, begin weak form matrices work --- grudge/op.py | 33 +++++++++++++++++++++++---------- requirements.txt | 2 +- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 9431ac901..951249f11 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -240,17 +240,13 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): ]) # unreshape grad to apply geometric factors - # NOTE: In a future version, do not reshape before application of - # geometric factors. Can possibly "chain" the einsum. For example, the - # simplicial case below has einsum with spec - # ("xrei,rij,ei->ei") - # for the strong local gradient case grad = make_obj_array([ unreshape_array_for_tensor_product_space(grp.space, grad[i]) for i in range(grp.dim) ]) # apply geometric factors + # TODO: chain the einsum above with the einsum below grad = actx.np.stack([grad[i] for i in range(dim)]) grad = actx.einsum( "xrei,xei->xei", @@ -328,11 +324,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): ]) # unreshape partials to apply geometric factors - # NOTE: In a future version, do not reshape before application of - # geometric factors. Can possibly "chain" the einsum. For example, the - # simplicial case below has einsum with spec - # ("xrei,rij,xej->ei") - # for the strong local divergence case + # TODO: chain the einsum above with the einsum below partials = make_obj_array([ unreshape_array_for_tensor_product_space(grp.space, partials[i]) for i in range(partials.shape[0]) @@ -579,7 +571,28 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): from meshmode.discretization.poly_element import \ mass_matrix, diff_matrices + if isinstance(out_grp, TensorProductElementGroupBase): + import modepy as mp + import numpy.linalg as la + + basis_1d = out_grp.bases_1d() + nodes_1d = out_grp.unit_nodes_1d + + vdm = mp.vandermonde(basis_1d.functions, nodes_1d) + vdm_p = mp.vandermonde(basis_1d.gradients, nodes_1d)[0] + + # NOTE: possibly work special-case matrices like differentiation + # matrix, mass matrix, into modepy + mmat = la.inv(vdm @ vdm.T) + diff_mat = vdm_p @ la.inv(vdm) + return actx.freeze( + actx.tag_axis(1, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray( + diff_mat.T @ mmat.T)))) + mmat = mass_matrix(out_grp) + return actx.freeze( actx.tag_axis(1, DiscretizationDOFAxisTag(), actx.from_numpy( diff --git a/requirements.txt b/requirements.txt index 2107e5aeb..f56f10888 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ git+https://github.com/inducer/leap.git#egg=leap git+https://github.com/inducer/meshpy.git#egg=meshpy git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/arraycontext.git#egg=arraycontext -git+https://github.com/inducer/meshmode.git#egg=meshmode +git+https://github.com/a-alveyblanc/meshmode.git@tensor-product-1d-nodes-and-1d-basis#egg=meshmode git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/pymetis.git#egg=pymetis git+https://github.com/illinois-ceesd/logpyle.git#egg=logpyle From 33a54e40c7966a5cfc24108cd8782ede7ddc0035 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 6 Sep 2023 11:31:04 -0500 Subject: [PATCH 18/66] TMP: Use outside actx in TP grad --- grudge/op.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 951249f11..212200600 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -213,10 +213,10 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - actx_tp = TensorProductArrayContext( - actx.queue, - allocator=actx.allocator, - force_device_scalars=actx._force_device_scalars) + # actx_tp = TensorProductArrayContext( + # actx.queue, + # allocator=actx.allocator, + # force_device_scalars=actx._force_device_scalars) from modepy.tools import ( reshape_array_for_tensor_product_space, @@ -229,7 +229,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) grad = make_obj_array([ - actx_tp.einsum( + actx.einsum( f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", diff_mat, vec, From e446bb76abc548014e54412fab45a45c8d7c01dc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 6 Sep 2023 11:31:21 -0500 Subject: [PATCH 19/66] Add TP transform cartoon --- examples/tp-transform-cartoon.py | 56 ++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 examples/tp-transform-cartoon.py diff --git a/examples/tp-transform-cartoon.py b/examples/tp-transform-cartoon.py new file mode 100644 index 000000000..7b2472076 --- /dev/null +++ b/examples/tp-transform-cartoon.py @@ -0,0 +1,56 @@ +import numpy as np +import pyopencl as cl +from meshmode.array_context import PytatoPyOpenCLArrayContext +import meshmode.mesh.generation as mgen +from grudge import op, DiscretizationCollection +from pytools.obj_array import make_obj_array + + +class MyArrayContext(PytatoPyOpenCLArrayContext): + pass + + +def main(): + order = 4 + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + actx = MyArrayContext(queue) + + dim = 3 + n = 5 + + from meshmode.mesh import TensorProductElementGroup + from meshmode.discretization.poly_element import \ + LegendreGaussLobattoTensorProductGroupFactory as LGL + + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, b=(1,)*dim, + nelements_per_axis=(n,)*dim, + group_cls=TensorProductElementGroup) + + import grudge.dof_desc as dd + dcoll = DiscretizationCollection( + actx, + mesh, + discr_tag_to_group_factory={ + dd.DISCR_TAG_BASE: LGL(order)}) + + def f(x): + result = dcoll.zeros(actx) + 1 + for i in range(dim-1): + result = result * actx.np.sin(np.pi*x[i]) + result = result * actx.np.cos(np.pi/2*x[dim-1]) + return result + + + x = actx.thaw(dcoll.nodes()) + + u = f(x) + + op.local_grad(dcoll, u) + + +if __name__ == "__main__": + main() + From 264192c1d4dbec10e2b394b80caebb394b3859a4 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 12 Sep 2023 10:54:57 -0500 Subject: [PATCH 20/66] Temporary changes to get tensor product gradient working again --- examples/tp-transform-cartoon.py | 2 +- grudge/op.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/tp-transform-cartoon.py b/examples/tp-transform-cartoon.py index 7b2472076..cbc23c267 100644 --- a/examples/tp-transform-cartoon.py +++ b/examples/tp-transform-cartoon.py @@ -48,7 +48,7 @@ def f(x): u = f(x) - op.local_grad(dcoll, u) + grad_u = op.local_grad(dcoll, u) if __name__ == "__main__": diff --git a/grudge/op.py b/grudge/op.py index 212200600..0b2e95a7d 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -213,10 +213,10 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - # actx_tp = TensorProductArrayContext( - # actx.queue, - # allocator=actx.allocator, - # force_device_scalars=actx._force_device_scalars) + actx_tp = TensorProductArrayContext( + actx.queue, + allocator=actx.allocator, + force_device_scalars=actx._force_device_scalars) from modepy.tools import ( reshape_array_for_tensor_product_space, @@ -229,7 +229,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) grad = make_obj_array([ - actx.einsum( + actx_tp.einsum( f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", diff_mat, vec, @@ -247,8 +247,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): # apply geometric factors # TODO: chain the einsum above with the einsum below - grad = actx.np.stack([grad[i] for i in range(dim)]) - grad = actx.einsum( + grad = actx_tp.np.stack([grad[i] for i in range(dim)]) + grad = actx_tp.einsum( "xrei,xei->xei", ijm, grad, From ba03b3fab772fc1b78fc05d2d712ae1943c81a06 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Wed, 13 Sep 2023 18:42:07 -0500 Subject: [PATCH 21/66] Tensor product array context related changes --- examples/tp-transform-cartoon.py | 28 +++++++++++++++++++++++++--- grudge/op.py | 21 ++++++--------------- test/test_op.py | 22 +++++++++++++++------- 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/examples/tp-transform-cartoon.py b/examples/tp-transform-cartoon.py index cbc23c267..4ebade58b 100644 --- a/examples/tp-transform-cartoon.py +++ b/examples/tp-transform-cartoon.py @@ -3,11 +3,32 @@ from meshmode.array_context import PytatoPyOpenCLArrayContext import meshmode.mesh.generation as mgen from grudge import op, DiscretizationCollection +from grudge.array_context import OutputIsTensorProductDOFArrayOrdered from pytools.obj_array import make_obj_array -class MyArrayContext(PytatoPyOpenCLArrayContext): - pass +class PytatoTensorProductArrayContext(PytatoPyOpenCLArrayContext): + def transform_loopy_program(self, t_unit): + + if len(t_unit.callables_table) == 1: + knl = t_unit.default_entrypoint + + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + t_unit = t_unit.with_kernel(knl) + + return super().transform_loopy_program(t_unit) def main(): @@ -15,7 +36,7 @@ def main(): ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) - actx = MyArrayContext(queue) + actx = PytatoTensorProductArrayContext(queue) dim = 3 n = 5 @@ -50,6 +71,7 @@ def f(x): grad_u = op.local_grad(dcoll, u) + pu.db if __name__ == "__main__": main() diff --git a/grudge/op.py b/grudge/op.py index 0b2e95a7d..b60ab62d7 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -213,11 +213,6 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - actx_tp = TensorProductArrayContext( - actx.queue, - allocator=actx.allocator, - force_device_scalars=actx._force_device_scalars) - from modepy.tools import ( reshape_array_for_tensor_product_space, unreshape_array_for_tensor_product_space) @@ -229,7 +224,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) grad = make_obj_array([ - actx_tp.einsum( + actx.einsum( f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", diff_mat, vec, @@ -247,8 +242,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): # apply geometric factors # TODO: chain the einsum above with the einsum below - grad = actx_tp.np.stack([grad[i] for i in range(dim)]) - grad = actx_tp.einsum( + grad = actx.np.stack([grad[i] for i in range(dim)]) + grad = actx.einsum( "xrei,xei->xei", ijm, grad, @@ -297,10 +292,6 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - actx_tp = TensorProductArrayContext( - actx.queue, - allocator=actx.allocator, - force_device_scalars=actx._force_device_scalars) from modepy.tools import ( reshape_array_for_tensor_product_space, @@ -313,7 +304,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) partials = make_obj_array([ - actx_tp.einsum( + actx.einsum( f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", diff_mat, vec[i], @@ -321,14 +312,14 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) for i in range(dim) - ]) + ]) # unreshape partials to apply geometric factors # TODO: chain the einsum above with the einsum below partials = make_obj_array([ unreshape_array_for_tensor_product_space(grp.space, partials[i]) for i in range(partials.shape[0]) - ]) + ]) # apply geometric factors partials = actx.np.stack([partials[i] for i in range(dim)]) diff --git a/test/test_op.py b/test/test_op.py index e836a63e1..05f6c1f99 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -169,13 +169,18 @@ def get_flux(u_tpair): (True, False), (True, True) ]) -def test_tensor_product_gradient(actx_factory, form, dim, order, vectorize, +def test_tensor_product_gradient(form, dim, order, vectorize, nested, visualize=False): """A "one-dimensional tensor product element" does not make sense, so the one-dimensional case is excluded from this test. """ - actx = actx_factory() + import pyopencl as cl + from grudge.array_context import TensorProductArrayContext + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + actx = TensorProductArrayContext(queue) from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() @@ -303,9 +308,7 @@ def get_flux(u_tpair): (True, True) ]) def test_divergence(actx_factory, form, dim, order, vectorize, nested, - visualize=False): - actx = actx_factory() - + visualize=False): from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() @@ -414,12 +417,17 @@ def get_flux(u_tpair): (True, False), (True, True) ]) -def test_tensor_product_divergence(actx_factory, form, dim, order, vectorize, +def test_tensor_product_divergence(form, dim, order, vectorize, nested, visualize=False): """A "one-dimensional tensor product element" does not make sense, so the one-dimensional case is excluded from this test. """ - actx = actx_factory() + import pyopencl as cl + from grudge.array_context import TensorProductArrayContext + + ctx = cl.create_some_context() + queue = cl.CommandQueue(ctx) + actx = TensorProductArrayContext(queue) from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() From 8263333024437c6b4e0bab7fe6dcf8d2d60ff9f4 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 14 Sep 2023 09:28:02 -0500 Subject: [PATCH 22/66] Update example --- examples/tp-transform-cartoon.py | 33 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/examples/tp-transform-cartoon.py b/examples/tp-transform-cartoon.py index 4ebade58b..a3f54e210 100644 --- a/examples/tp-transform-cartoon.py +++ b/examples/tp-transform-cartoon.py @@ -4,29 +4,26 @@ import meshmode.mesh.generation as mgen from grudge import op, DiscretizationCollection from grudge.array_context import OutputIsTensorProductDOFArrayOrdered -from pytools.obj_array import make_obj_array class PytatoTensorProductArrayContext(PytatoPyOpenCLArrayContext): def transform_loopy_program(self, t_unit): - if len(t_unit.callables_table) == 1: - knl = t_unit.default_entrypoint + knl = t_unit.default_entrypoint + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) - if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): - new_args = [] - for arg in knl.args: - if arg.is_output: - arg = arg.copy(dim_tags=( - f"N{len(arg.shape)-1}," - + ",".join(f"N{i}" - for i in range(len(arg.shape)-1)) - )) + new_args.append(arg) - new_args.append(arg) - - knl = knl.copy(args=new_args) - t_unit = t_unit.with_kernel(knl) + knl = knl.copy(args=new_args) + t_unit = t_unit.with_kernel(knl) return super().transform_loopy_program(t_unit) @@ -39,7 +36,7 @@ def main(): actx = PytatoTensorProductArrayContext(queue) dim = 3 - n = 5 + res = 5 from meshmode.mesh import TensorProductElementGroup from meshmode.discretization.poly_element import \ @@ -47,7 +44,7 @@ def main(): mesh = mgen.generate_regular_rect_mesh( a=(-1,)*dim, b=(1,)*dim, - nelements_per_axis=(n,)*dim, + nelements_per_axis=(res,)*dim, group_cls=TensorProductElementGroup) import grudge.dof_desc as dd From 6b5002846026dea998a52d686a00d094c8a22722 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 14 Sep 2023 13:10:02 -0500 Subject: [PATCH 23/66] Add code for printing generated differentiation code --- examples/tp-transform-cartoon.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/tp-transform-cartoon.py b/examples/tp-transform-cartoon.py index a3f54e210..76ce6d889 100644 --- a/examples/tp-transform-cartoon.py +++ b/examples/tp-transform-cartoon.py @@ -1,5 +1,7 @@ import numpy as np import pyopencl as cl +import pytato as pt +import loopy as lp from meshmode.array_context import PytatoPyOpenCLArrayContext import meshmode.mesh.generation as mgen from grudge import op, DiscretizationCollection @@ -24,7 +26,6 @@ def transform_loopy_program(self, t_unit): knl = knl.copy(args=new_args) t_unit = t_unit.with_kernel(knl) - return super().transform_loopy_program(t_unit) @@ -67,7 +68,12 @@ def f(x): u = f(x) grad_u = op.local_grad(dcoll, u) + grad_u = actx.np.stack(grad_u)[0] + + prg = pt.generate_loopy(grad_u).program + code = lp.generate_code_v2(prg).device_code() + print(code) pu.db if __name__ == "__main__": From 5d36bfb916f9405e8d7728feda8a1c5ad6b703c1 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 18 Sep 2023 09:42:41 -0500 Subject: [PATCH 24/66] Update strong tp diff example --- examples/tp-transform-cartoon.py | 51 ++++++++++++++++++++++---------- grudge/op.py | 13 ++++++-- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/examples/tp-transform-cartoon.py b/examples/tp-transform-cartoon.py index 76ce6d889..b0bb13cfd 100644 --- a/examples/tp-transform-cartoon.py +++ b/examples/tp-transform-cartoon.py @@ -1,17 +1,31 @@ +import loopy as lp + +import meshmode.mesh.generation as mgen + import numpy as np import pyopencl as cl import pytato as pt -import loopy as lp -from meshmode.array_context import PytatoPyOpenCLArrayContext -import meshmode.mesh.generation as mgen -from grudge import op, DiscretizationCollection + +from grudge import op from grudge.array_context import OutputIsTensorProductDOFArrayOrdered +from grudge.discretization import make_discretization_collection + +from meshmode.array_context import PytatoPyOpenCLArrayContext class PytatoTensorProductArrayContext(PytatoPyOpenCLArrayContext): - def transform_loopy_program(self, t_unit): + def transform_dag(self, dag): + if "dag_dots" not in dir(self): + self.dag_dots = [] + self.dag_dots.append(pt.get_dot_graph(dag)) + + return super().transform_dag(dag) + + def transform_loopy_program(self, t_unit): knl = t_unit.default_entrypoint + + # {{{ adjust strides according to tensor product structure if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): new_args = [] for arg in knl.args: @@ -25,19 +39,31 @@ def transform_loopy_program(self, t_unit): new_args.append(arg) knl = knl.copy(args=new_args) - t_unit = t_unit.with_kernel(knl) + # }}} + + # {{{ prefetch + # }}} + + # {{{ tile + # }}} + + # FIXME: remove this (eventually) + knl = lp.set_options(knl, insert_gbarriers=True) + t_unit = t_unit.with_kernel(knl) + self.dev_code = lp.generate_code_v2(t_unit).device_code() + return super().transform_loopy_program(t_unit) def main(): - order = 4 + order = 1 ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) actx = PytatoTensorProductArrayContext(queue) dim = 3 - res = 5 + res = 2 from meshmode.mesh import TensorProductElementGroup from meshmode.discretization.poly_element import \ @@ -49,7 +75,7 @@ def main(): group_cls=TensorProductElementGroup) import grudge.dof_desc as dd - dcoll = DiscretizationCollection( + dcoll = make_discretization_collection( actx, mesh, discr_tag_to_group_factory={ @@ -69,12 +95,7 @@ def f(x): grad_u = op.local_grad(dcoll, u) grad_u = actx.np.stack(grad_u)[0] - - prg = pt.generate_loopy(grad_u).program - code = lp.generate_code_v2(prg).device_code() - - print(code) - pu.db + pt.show_dot_graph(grad_u) if __name__ == "__main__": main() diff --git a/grudge/op.py b/grudge/op.py index b60ab62d7..38df02418 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -242,13 +242,15 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): # apply geometric factors # TODO: chain the einsum above with the einsum below + from arraycontext.metadata import NameHint grad = actx.np.stack([grad[i] for i in range(dim)]) grad = actx.einsum( "xrei,xei->xei", ijm, grad, arg_names=("inv_jac_t", "vec"), - tagged=(FirstAxisIsElementsTag(),)) + tagged=(FirstAxisIsElementsTag(), + NameHint("tp_gradient"),)) return grad @@ -383,7 +385,14 @@ def get_ref_derivative_mats(grp): vdm_1d = mp.vandermonde(bases_1d.functions, nodes1d) vdm_p_1d = mp.vandermonde(bases_1d.gradients, nodes1d)[0] - return actx.freeze(actx.from_numpy(vdm_p_1d @ la.inv(vdm_1d))) + diff_mat = actx.from_numpy(vdm_p_1d @ la.inv(vdm_1d)) + + from arraycontext.metadata import NameHint + return actx.freeze( + actx.tag(NameHint("tp_diff_mat_1d"), + tag_axes(actx, { + 1: DiscretizationDOFAxisTag()}, + diff_mat))) else: from meshmode.discretization.poly_element import diff_matrices From 92991a3d0524021333d2d4089a50e53df39340e9 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 30 Sep 2023 13:28:43 -0500 Subject: [PATCH 25/66] Version 0.1 of weak gradient computation --- grudge/op.py | 158 +++++++++++++++++++++++++++++++++++++----------- test/test_op.py | 4 +- 2 files changed, 126 insertions(+), 36 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 38df02418..062c88101 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -207,12 +207,11 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. - - def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): + def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, + metric_in_matvec): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - from modepy.tools import ( reshape_array_for_tensor_product_space, unreshape_array_for_tensor_product_space) @@ -223,41 +222,122 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm): # apply operators to function data dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) - grad = make_obj_array([ - actx.einsum( - f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", - diff_mat, - vec, - arg_names=("diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(dim) + + # weak form case: + # 3D weak_x: einsum("estu,ps,qt,ru->epqr", + # f, stiff_1D, mass_1D, mass_1D) + if metric_in_matvec: + stiff_1D, mass_1D = diff_mat + + if dim == 3: + weak_x = actx.einsum( + "estu,ps,qt,ru->epqr", + vec, + stiff_1D, + mass_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + weak_y = actx.einsum( + "estu,ps,qt,ru->epqr", + vec, + mass_1D, + stiff_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + weak_z = actx.einsum( + "estu,ps,qt,ru->epqr", + vec, + mass_1D, + mass_1D, + stiff_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + grad = make_obj_array([ + weak_x, + weak_y, + weak_z + ]) + + elif dim == 2: + weak_x = actx.einsum( + "est,ps,qt->epq", + vec, + stiff_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + weak_y = actx.einsum( + "est,ps,qt->epq", + vec, + mass_1D, + stiff_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + grad = make_obj_array([ + weak_x, + weak_y + ]) + + # strong form case: + # x partial: einsum("il,eljk->eijk", D, f) + else: + grad = make_obj_array([ + actx.einsum( + f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", + diff_mat, + vec, + arg_names=("diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for i in range(dim) ]) # unreshape grad to apply geometric factors grad = make_obj_array([ unreshape_array_for_tensor_product_space(grp.space, grad[i]) for i in range(grp.dim) - ]) + ]) - # apply geometric factors - # TODO: chain the einsum above with the einsum below + # apply geometric factors in strong case from arraycontext.metadata import NameHint - grad = actx.np.stack([grad[i] for i in range(dim)]) - grad = actx.einsum( - "xrei,xei->xei", - ijm, - grad, - arg_names=("inv_jac_t", "vec"), - tagged=(FirstAxisIsElementsTag(), - NameHint("tp_gradient"),)) + if metric_in_matvec: + grad = make_obj_array([ + actx.einsum( + "rei,ei->ei", + ijm[i], + grad[i], + arg_names=("inv_jac_t", "vec"), + tagged=FirstAxisIsElementsTag()) + for i in range(dim) + ]) + else: + grad = actx.np.stack([grad[i] for i in range(dim)]) + grad = actx.einsum( + "xrei,xei->xei", + ijm, + grad, + arg_names=("inv_jac_t", "vec"), + tagged=(FirstAxisIsElementsTag(), + NameHint("tp_gradient"),)) return grad - per_group_grads = [ - compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i) + compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i, + metric_in_matvec) if isinstance(in_grp, TensorProductElementGroupBase) # r for rst axis @@ -289,7 +369,6 @@ def _divergence_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. - def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) @@ -571,6 +650,8 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): from meshmode.discretization.poly_element import \ mass_matrix, diff_matrices + # {{{ tensor product case + if isinstance(out_grp, TensorProductElementGroupBase): import modepy as mp import numpy.linalg as la @@ -581,15 +662,24 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): vdm = mp.vandermonde(basis_1d.functions, nodes_1d) vdm_p = mp.vandermonde(basis_1d.gradients, nodes_1d)[0] - # NOTE: possibly work special-case matrices like differentiation - # matrix, mass matrix, into modepy - mmat = la.inv(vdm @ vdm.T) - diff_mat = vdm_p @ la.inv(vdm) - return actx.freeze( + mass_1D = la.inv(vdm @ vdm.T) + diff_mat = la.solve(vdm.T, vdm_p.T).T + + stiff_1D = actx.freeze( actx.tag_axis(1, DiscretizationDOFAxisTag(), - actx.from_numpy( - np.asarray( - diff_mat.T @ mmat.T)))) + actx.from_numpy( + np.asarray( + diff_mat.T @ mass_1D.T)))) + + mass_1D = actx.freeze( + actx.tag_axis(1, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray( + mass_1D)))) + + return (stiff_1D, mass_1D) + + # }}} mmat = mass_matrix(out_grp) diff --git a/test/test_op.py b/test/test_op.py index 05f6c1f99..7d26280da 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -161,7 +161,7 @@ def get_flux(u_tpair): or eoc_rec.max_error() < 1e-11) -@pytest.mark.parametrize("form", ["strong"]) +@pytest.mark.parametrize("form", ["weak"]) @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ @@ -308,7 +308,7 @@ def get_flux(u_tpair): (True, True) ]) def test_divergence(actx_factory, form, dim, order, vectorize, nested, - visualize=False): + visualize=False): from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() From 49116ab388a4245857066fc5e9d078bc1f164b96 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 30 Sep 2023 15:27:44 -0500 Subject: [PATCH 26/66] Weak form divergence version 0.1 --- grudge/op.py | 141 +++++++++++++++++++++++++++++++++++++++--------- test/test_op.py | 6 ++- 2 files changed, 120 insertions(+), 27 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 062c88101..8539f305a 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -381,37 +381,128 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): # reshape u to expose tensor product structure vec = reshape_array_for_tensor_product_space(grp.space, vec) - # apply differentiation matrix to function data dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) - partials = make_obj_array([ - actx.einsum( - f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", - diff_mat, - vec[i], - arg_names=("diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + + # weak form + if metric_in_matvec: + stiff_1D, mass_1D = diff_mat + + if dim == 3: + weak_x = actx.einsum( + "estu,ps,qt,ru->epqr", + vec[0], + stiff_1D, + mass_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + weak_y = actx.einsum( + "estu,ps,qt,ru->epqr", + vec[1], + mass_1D, + stiff_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + weak_z = actx.einsum( + "estu,ps,qt,ru->epqr", + vec[2], + mass_1D, + mass_1D, + stiff_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + partials = make_obj_array([ + weak_x, weak_y, weak_z + ]) + + elif dim == 2: + weak_x = actx.einsum( + "est,ps,qt->epq", + vec[0], + stiff_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + weak_y = actx.einsum( + "est,ps,qt->epq", + vec[1], + mass_1D, + stiff_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + + partials = make_obj_array([ + weak_x, weak_y + ]) + + else: + raise Exception("Dimensions of 2 and 3 are supported by " + "tensor product elements. Found dim = {dim}") + + + partials = make_obj_array([ + unreshape_array_for_tensor_product_space(grp.space, partials[i]) for i in range(dim) - ]) + ]) - # unreshape partials to apply geometric factors - # TODO: chain the einsum above with the einsum below - partials = make_obj_array([ - unreshape_array_for_tensor_product_space(grp.space, partials[i]) - for i in range(partials.shape[0]) - ]) + partials = actx.np.stack(partials) - # apply geometric factors - partials = actx.np.stack([partials[i] for i in range(dim)]) - div = actx.einsum( - "xrei,xei->ei", - ijm, - partials, - arg_names=("inv_jac_t", "vec"), - tagged=(FirstAxisIsElementsTag(),)) + div = make_obj_array([ + actx.einsum("rei,ei->ei", + ijm[i], + partials[i], + arg_names=("inv_jac_t", "vec"), + tagged=(FirstAxisIsElementsTag(),)) + for i in range(dim) + ]) + + ret = 0 + for i in range(dim): + ret += div[i] + return ret + + # strong form + else: + partials = make_obj_array([ + actx.einsum( + f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", + diff_mat, + vec[i], + arg_names=("diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for i in range(dim) + ]) + + # unreshape partials to apply geometric factors + # TODO: chain the einsum above with the einsum below + partials = make_obj_array([ + unreshape_array_for_tensor_product_space(grp.space, partials[i]) + for i in range(partials.shape[0]) + ]) + + # apply geometric factors + partials = actx.np.stack([partials[i] for i in range(dim)]) + + div = actx.einsum( + "xrei,xei->ei", + ijm, + partials, + arg_names=("inv_jac_t", "vec"), + tagged=(FirstAxisIsElementsTag(),)) - return div + return div per_group_divs = [ diff --git a/test/test_op.py b/test/test_op.py index 7d26280da..a5d731533 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -161,7 +161,7 @@ def get_flux(u_tpair): or eoc_rec.max_error() < 1e-11) -@pytest.mark.parametrize("form", ["weak"]) +@pytest.mark.parametrize("form", ["strong", "weak"]) @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ @@ -309,6 +309,8 @@ def get_flux(u_tpair): ]) def test_divergence(actx_factory, form, dim, order, vectorize, nested, visualize=False): + actx = actx_factory() + from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() @@ -409,7 +411,7 @@ def get_flux(u_tpair): or eoc_rec.max_error() < 1e-11) -@pytest.mark.parametrize("form", ["strong"]) +@pytest.mark.parametrize("form", ["strong", "weak"]) @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ From d87c19fee3af3e3d5958486bb27a8d8d0ee21f68 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 1 Oct 2023 18:28:55 -0500 Subject: [PATCH 27/66] Move TP array contexts. Add acoustic pulse TP example --- .../tensor-product-examples/acoustic_pulse.py | 264 ++++++++++++++++++ examples/tp-transform-cartoon.py | 2 +- grudge/array_context.py | 121 +++++--- 3 files changed, 347 insertions(+), 40 deletions(-) create mode 100644 examples/tensor-product-examples/acoustic_pulse.py diff --git a/examples/tensor-product-examples/acoustic_pulse.py b/examples/tensor-product-examples/acoustic_pulse.py new file mode 100644 index 000000000..13c2194cf --- /dev/null +++ b/examples/tensor-product-examples/acoustic_pulse.py @@ -0,0 +1,264 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +from meshmode.mesh import TensorProductElementGroup +import numpy as np + +import pyopencl as cl +import pyopencl.tools as cl_tools + +from grudge.array_context import ( + PyOpenCLArrayContext, + PytatoPyOpenCLArrayContext +) +from grudge.models.euler import ( + ConservedEulerField, + EulerOperator, + InviscidWallBC +) +from grudge.shortcuts import rk4_step + +from meshmode.mesh import BTAG_ALL + +from pytools.obj_array import make_obj_array + +import grudge.op as op + +import logging +logger = logging.getLogger(__name__) + + +def gaussian_profile( + x_vec, t=0, rho0=1.0, rhoamp=1.0, p0=1.0, gamma=1.4, + center=None, velocity=None): + + dim = len(x_vec) + if center is None: + center = np.zeros(shape=(dim,)) + if velocity is None: + velocity = np.zeros(shape=(dim,)) + + lump_loc = center + t * velocity + + # coordinates relative to lump center + rel_center = make_obj_array( + [x_vec[i] - lump_loc[i] for i in range(dim)] + ) + actx = x_vec[0].array_context + r = actx.np.sqrt(np.dot(rel_center, rel_center)) + expterm = rhoamp * actx.np.exp(1 - r ** 2) + + mass = expterm + rho0 + mom = velocity * mass + energy = (p0 / (gamma - 1.0)) + np.dot(mom, mom) / (2.0 * mass) + + return ConservedEulerField(mass=mass, energy=energy, momentum=mom) + + +def make_pulse(amplitude, r0, w, r): + dim = len(r) + r_0 = np.zeros(dim) + r_0 = r_0 + r0 + rel_center = make_obj_array( + [r[i] - r_0[i] for i in range(dim)] + ) + actx = r[0].array_context + rms2 = w * w + r2 = np.dot(rel_center, rel_center) / rms2 + return amplitude * actx.np.exp(-.5 * r2) + + +def acoustic_pulse_condition(x_vec, t=0): + dim = len(x_vec) + vel = np.zeros(shape=(dim,)) + orig = np.zeros(shape=(dim,)) + uniform_gaussian = gaussian_profile( + x_vec, t=t, center=orig, velocity=vel, rhoamp=0.0) + + amplitude = 1.0 + width = 0.1 + pulse = make_pulse(amplitude, orig, width, x_vec) + + return ConservedEulerField( + mass=uniform_gaussian.mass, + energy=uniform_gaussian.energy + pulse, + momentum=uniform_gaussian.momentum + ) + + +def run_acoustic_pulse(actx, + order=3, + final_time=1, + resolution=4, + overintegration=False, + visualize=False): + + # eos-related parameters + gamma = 1.4 + + # {{{ discretization + + from meshmode.mesh.generation import generate_regular_rect_mesh + + dim = 3 + box_ll = -0.5 + box_ur = 0.5 + mesh = generate_regular_rect_mesh( + a=(box_ll,)*dim, + b=(box_ur,)*dim, + nelements_per_axis=(resolution,)*dim, + group_cls=TensorProductElementGroup) + + from grudge import DiscretizationCollection + from grudge.dof_desc import DISCR_TAG_BASE, DISCR_TAG_QUAD + from meshmode.discretization.poly_element import \ + LegendreGaussLobattoTensorProductGroupFactory as LGL + + exp_name = f"fld-acoustic-pulse-N{order}-K{resolution}" + if overintegration: + exp_name += "-overintegrated" + quad_tag = DISCR_TAG_QUAD + else: + quad_tag = None + + dcoll = DiscretizationCollection( + actx, mesh, + discr_tag_to_group_factory={ + DISCR_TAG_BASE: LGL(order) + } + ) + + # }}} + + # {{{ Euler operator + + euler_operator = EulerOperator( + dcoll, + bdry_conditions={BTAG_ALL: InviscidWallBC()}, + flux_type="lf", + gamma=gamma, + quadrature_tag=quad_tag + ) + + def rhs(t, q): + return euler_operator.operator(t, q) + + compiled_rhs = actx.compile(rhs) + + from grudge.dt_utils import h_min_from_volume + + cfl = 0.125 + cn = 0.5*(order + 1)**2 + dt = cfl * actx.to_numpy(h_min_from_volume(dcoll)) / cn + + fields = acoustic_pulse_condition(actx.thaw(dcoll.nodes())) + + logger.info("Timestep size: %g", dt) + + # }}} + + from grudge.shortcuts import make_visualizer + + vis = make_visualizer(dcoll) + + # {{{ time stepping + + step = 0 + t = 0.0 + while t < final_time: + if step % 10 == 0: + norm_q = actx.to_numpy(op.norm(dcoll, fields, 2)) + logger.info("[%04d] t = %.5f |q| = %.5e", step, t, norm_q) + if visualize: + vis.write_vtk_file( + f"{exp_name}-{step:04d}.vtu", + [ + ("rho", fields.mass), + ("energy", fields.energy), + ("momentum", fields.momentum) + ] + ) + assert norm_q < 5 + + fields = actx.thaw(actx.freeze(fields)) + fields = rk4_step(fields, t, dt, compiled_rhs) + t += dt + step += 1 + + # }}} + + +def main(ctx_factory, order=3, final_time=1, resolution=16, + overintegration=False, visualize=False, lazy=False): + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + if lazy: + from grudge.array_context import PytatoTensorProductArrayContext + actx = PytatoTensorProductArrayContext( + queue, + allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), + ) + else: + from grudge.array_context import TensorProductArrayContext + actx = TensorProductArrayContext( + queue, + allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), + force_device_scalars=True, + ) + + run_acoustic_pulse( + actx, + order=order, + resolution=resolution, + overintegration=overintegration, + final_time=final_time, + visualize=visualize + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--order", default=3, type=int) + parser.add_argument("--tfinal", default=0.1, type=float) + parser.add_argument("--resolution", default=16, type=int) + parser.add_argument("--oi", action="store_true", + help="use overintegration") + parser.add_argument("--visualize", action="store_true", + help="write out vtk output") + parser.add_argument("--lazy", action="store_true", + help="switch to a lazy computation mode") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + main(cl.create_some_context, + order=args.order, + final_time=args.tfinal, + resolution=args.resolution, + overintegration=args.oi, + visualize=args.visualize, + lazy=args.lazy) diff --git a/examples/tp-transform-cartoon.py b/examples/tp-transform-cartoon.py index b0bb13cfd..f26a58737 100644 --- a/examples/tp-transform-cartoon.py +++ b/examples/tp-transform-cartoon.py @@ -62,7 +62,7 @@ def main(): queue = cl.CommandQueue(ctx) actx = PytatoTensorProductArrayContext(queue) - dim = 3 + dim = 2 res = 2 from meshmode.mesh import TensorProductElementGroup diff --git a/grudge/array_context.py b/grudge/array_context.py index f29e8ef5d..4df7df0e2 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -127,45 +127,6 @@ def __init__(self, queue: "pyopencl.CommandQueue", # }}} -# {{{ Tensor product array context - -class OutputIsTensorProductDOFArrayOrdered(Tag): - """Signify that the strides will not be of order "C" or "F". See - :class:`grudge.array_context.TensorProductArrayContext` for more details. - """ - pass - - -class TensorProductArrayContext(_PyOpenCLArrayContextBase): - """Specialized array context for use with tensor product elements. - - The strides for the arrays containing tensor product element data are of the - form (slow, fastest, faster, fast). These strides are not "C" or "F" order. - Hence, this specialized array context takes care of specifying the - particular strides required. - """ - - def transform_loopy_program(self, t_unit): - if len(t_unit.callables_table) == 1: - knl = t_unit.default_entrypoint - if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): - new_args = [] - for arg in knl.args: - if arg.is_output: - arg = arg.copy(dim_tags=( - f"N{len(arg.shape)-1}," - + ",".join(f"N{i}" - for i in range(len(arg.shape)-1)) - )) - - new_args.append(arg) - - knl = knl.copy(args=new_args) - t_unit = t_unit.with_kernel(knl) - - return super().transform_loopy_program(t_unit) -# }}} - # {{{ pytato @@ -631,4 +592,86 @@ def get_reasonable_array_context_class( # }}} +# {{{ Tensor product array context + +# {{{ Relevant tags +class OutputIsTensorProductDOFArrayOrdered(Tag): + """Signify that the strides will not be of order "C" or "F". See + :class:`grudge.array_context.TensorProductArrayContext` for more details. + """ + pass +# }}} + +# {{{ Eager TP array context +class TensorProductArrayContext(_PyOpenCLArrayContextBase): + """Specialized array context for use with tensor product elements. + + The strides for the arrays containing tensor product element data are of the + form (slow, fastest, faster, fast). These strides are not "C" or "F" order. + Hence, this specialized array context takes care of specifying the + particular strides required. + """ + + def transform_loopy_program(self, t_unit): + if len(t_unit.callables_table) == 1: + knl = t_unit.default_entrypoint + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + t_unit = t_unit.with_kernel(knl) + + return super().transform_loopy_program(t_unit) +# }}} + +# {{{ Lazy tensor product array context +class PytatoTensorProductArrayContext(PytatoPyOpenCLArrayContext): + def transform_dag(self, dag): + return super().transform_dag(dag) + + def transform_loopy_program(self, t_unit): + knl = t_unit.default_entrypoint + + # {{{ adjust strides according to tensor product structure + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + # }}} + + # {{{ prefetch + # }}} + + # {{{ tile + # }}} + + import loopy as lp + # FIXME: remove this (eventually) + knl = lp.set_options(knl, insert_gbarriers=True) + t_unit = t_unit.with_kernel(knl) + self.dev_code = lp.generate_code_v2(t_unit).device_code() + + return super().transform_loopy_program(t_unit) +# }}} + +# }}} + # vim: foldmethod=marker From b8e24c4a1329fd91c0d81a5fa13b3731a0862409 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 8 Oct 2023 11:40:47 -0500 Subject: [PATCH 28/66] Start simplifying einsum --- grudge/op.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 8539f305a..0e524f3cd 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -231,7 +231,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, if dim == 3: weak_x = actx.einsum( - "estu,ps,qt,ru->epqr", + "eltu,pl,qt,ru->epqr", vec, stiff_1D, mass_1D, @@ -241,21 +241,21 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "estu,ps,qt,ru->epqr", + "eslu,pl,qs,ru->eqpr", vec, - mass_1D, stiff_1D, mass_1D, + mass_1D, arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) weak_z = actx.einsum( - "estu,ps,qt,ru->epqr", + "estl,pl,qs,rt->eqrp", vec, + stiff_1D, mass_1D, mass_1D, - stiff_1D, arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -268,7 +268,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, elif dim == 2: weak_x = actx.einsum( - "est,ps,qt->epq", + "elt,pl,qt->epq", vec, stiff_1D, mass_1D, @@ -277,10 +277,10 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "est,ps,qt->epq", + "esl,pl,qs->eqp", vec, - mass_1D, stiff_1D, + mass_1D, arg_names=("vec", "stiff_1D_r", "mass_1D_s"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) From 645a5049c7a93901fd2b04b6d59d28fcd5161217 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 3 Nov 2023 00:28:13 -0500 Subject: [PATCH 29/66] Remove unnecessary code --- grudge/array_context.py | 12 ------------ grudge/op.py | 1 - 2 files changed, 13 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index dbba49efa..219cc6a1f 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -655,18 +655,6 @@ def transform_loopy_program(self, t_unit): knl = knl.copy(args=new_args) # }}} - # {{{ prefetch - # }}} - - # {{{ tile - # }}} - - import loopy as lp - # FIXME: remove this (eventually) - knl = lp.set_options(knl, insert_gbarriers=True) - t_unit = t_unit.with_kernel(knl) - self.dev_code = lp.generate_code_v2(t_unit).device_code() - return super().transform_loopy_program(t_unit) # }}} diff --git a/grudge/op.py b/grudge/op.py index 0e524f3cd..d3f9c6ce6 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -450,7 +450,6 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): raise Exception("Dimensions of 2 and 3 are supported by " "tensor product elements. Found dim = {dim}") - partials = make_obj_array([ unreshape_array_for_tensor_product_space(grp.space, partials[i]) for i in range(dim) From 9b70d0b1736ef0cc2ed6e4891ebe24a2f7479976 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 3 Nov 2023 00:30:03 -0500 Subject: [PATCH 30/66] Correct arg names in weak form gradient --- grudge/op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index d3f9c6ce6..089204276 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -405,7 +405,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): mass_1D, stiff_1D, mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + arg_names=("vec", "mass_1D_r", "stiff_1D_s", "mass_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -415,7 +415,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): mass_1D, mass_1D, stiff_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + arg_names=("vec", "mass_1D_r", "mass_1D_s", "stiff_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -438,7 +438,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): vec[1], mass_1D, stiff_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + arg_names=("vec", "mass_1D_r", "stiff_1D_s"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) From ed8aacfbe535d11aea769187c1b3c25bfd840eaf Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 3 Nov 2023 00:31:21 -0500 Subject: [PATCH 31/66] Correct arg names in weak form operator application --- grudge/op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 089204276..874517fc9 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -246,7 +246,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, stiff_1D, mass_1D, mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + arg_names=("vec", "mass_1D_r", "stiff_1D_s", "mass_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -256,7 +256,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, stiff_1D, mass_1D, mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + arg_names=("vec", "mass_1D_r", "mass_1D_s", "stiff_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) From b1c312d6ab4450ac1364256e15f6b2dc9eb0d97a Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 3 Nov 2023 00:32:28 -0500 Subject: [PATCH 32/66] Same as last commit --- grudge/op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grudge/op.py b/grudge/op.py index 874517fc9..97d74632b 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -281,7 +281,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, vec, stiff_1D, mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + arg_names=("vec", "mass_1D_r", "stiff_1D_s"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) From 0cb1830bd86a1a56708ff6c3b5fabde42b3b16c5 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 7 Dec 2023 02:24:18 -0600 Subject: [PATCH 33/66] Offload simplicial grad/div to their own function. Slight refactors. --- grudge/op.py | 151 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 89 insertions(+), 62 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 97d74632b..798e38216 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -84,7 +84,6 @@ from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc from grudge.array_context import ( - TensorProductArrayContext, OutputIsTensorProductDOFArrayOrdered) from pytools import keyed_memoize_in @@ -207,6 +206,8 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. + # {{{ tensor product gradient + def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): """Exploits tensor product structure to differentiate each coordinate @@ -231,34 +232,34 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, if dim == 3: weak_x = actx.einsum( - "eltu,pl,qt,ru->epqr", - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "eltu,pl,qt,ru->epqr", + vec, + stiff_1D, + mass_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "eslu,pl,qs,ru->eqpr", - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("vec", "mass_1D_r", "stiff_1D_s", "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "eslu,pl,qs,ru->eqpr", + vec, + stiff_1D, + mass_1D, + mass_1D, + arg_names=("vec", "mass_1D_r", "stiff_1D_s", "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) weak_z = actx.einsum( - "estl,pl,qs,rt->eqrp", - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("vec", "mass_1D_r", "mass_1D_s", "stiff_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "estl,pl,qs,rt->eqrp", + vec, + stiff_1D, + mass_1D, + mass_1D, + arg_names=("vec", "mass_1D_r", "mass_1D_s", "stiff_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) grad = make_obj_array([ weak_x, @@ -268,22 +269,22 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, elif dim == 2: weak_x = actx.einsum( - "elt,pl,qt->epq", - vec, - stiff_1D, - mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "elt,pl,qt->epq", + vec, + stiff_1D, + mass_1D, + arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "esl,pl,qs->eqp", - vec, - stiff_1D, - mass_1D, - arg_names=("vec", "mass_1D_r", "stiff_1D_s"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "esl,pl,qs->eqp", + vec, + stiff_1D, + mass_1D, + arg_names=("vec", "mass_1D_r", "stiff_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) grad = make_obj_array([ weak_x, @@ -294,14 +295,14 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # x partial: einsum("il,eljk->eijk", D, f) else: grad = make_obj_array([ - actx.einsum( - f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", - diff_mat, - vec, - arg_names=("diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(dim) + actx.einsum( + f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", + diff_mat, + vec, + arg_names=("diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for i in range(dim) ]) # unreshape grad to apply geometric factors @@ -334,15 +335,14 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, return grad - per_group_grads = [ + # }}} - compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i, - metric_in_matvec) - if isinstance(in_grp, TensorProductElementGroupBase) - # r for rst axis - # x for xyz axis - else actx.einsum( + # {{{ simplicial grad + + def compute_simplicial_grad(actx, in_grp, out_grp, get_diff_mat, vec_i, + ijm_i, metric_in_matvec): + return actx.einsum( "xrej,rij,ej->xei" if metric_in_matvec else "xrei,rij,ej->xei", ijm_i, get_diff_mat( @@ -354,9 +354,20 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, arg_names=("inv_jac_t", "ref_stiffT_mat", "vec"), tagged=(FirstAxisIsElementsTag(),)) + # }}} + + + per_group_grads = [ + compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i, + metric_in_matvec) + if isinstance(in_grp, TensorProductElementGroupBase) + else compute_simplicial_grad(actx, in_grp, out_grp, get_diff_mat, vec_i, + ijm_i, metric_in_matvec) + for out_grp, in_grp, vec_i, ijm_i in zip( out_discr.groups, in_discr.groups, vec, - inv_jac_mat)] + inv_jac_mat) + ] return make_obj_array([ DOFArray( @@ -369,6 +380,9 @@ def _divergence_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. + + # {{{ tensor product div + def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) @@ -503,15 +517,14 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): return div + # }}} - per_group_divs = [ - compute_tensor_product_div(actx, in_grp, get_diff_mat, vec_i, ijm_i) - if isinstance(in_grp, TensorProductElementGroupBase) + # {{{ simplicial div - # r for rst axis - # x for xyz axis - else actx.einsum( + def compute_simplicial_div(actx, in_grp, out_grp, get_diff_mat, vec_i, + ijm_i, metric_in_matvec): + return actx.einsum( "xrej,rij,xej->ei" if metric_in_matvec else "xrei,rij,xej->ei", ijm_i, get_diff_mat( @@ -523,9 +536,23 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): arg_names=("inv_jac_t", "ref_stiffT_mat", "vec"), tagged=(FirstAxisIsElementsTag(),)) + # }}} + + + per_group_divs = [ + + compute_tensor_product_div(actx, in_grp, get_diff_mat, vec_i, ijm_i) + if isinstance(in_grp, TensorProductElementGroupBase) + + # r for rst axis + # x for xyz axis + else compute_simplicial_div(actx, in_grp, out_grp, get_diff_mat, vec_i, + ijm_i, metric_in_matvec) + for out_grp, in_grp, vec_i, ijm_i in zip( out_discr.groups, in_discr.groups, vec, - inv_jac_mat)] + inv_jac_mat) + ] return DOFArray(actx, data=tuple(per_group_divs)) From 984a9843d248cafe6cda9c8ab417b2acf576c103 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 7 Dec 2023 17:08:42 -0600 Subject: [PATCH 34/66] More refactoring. Add FIXMEs for future changes (another PR probably) --- grudge/op.py | 55 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 798e38216..3b2d5d33b 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -182,20 +182,50 @@ def _single_axis_derivative_kernel( # - whether the chain rule terms ("inv_jac_mat") sit outside (strong) # or inside (weak) the matrix-vector product that carries out the # derivative, cf. "metric_in_matvec". + + + # {{{ tensor product single axis derivative + + def compute_tensor_product_derivative(actx, in_grp, out_grp, + get_diff_mat, vec, ijm, + xyz_axis, metric_in_matvec): + return compute_simplicial_derivative(actx, in_grp, out_grp, + get_diff_mat, vec, ijm, + xyz_axis, metric_in_matvec) + + # }}} + + + # {{{ simplicial single axis derivative + + def compute_simplicial_derivative(actx, in_grp, out_grp, + get_diff_mat, vec_i, ijm_i, + xyz_axis, metric_in_matvec): + # r for rst axis + return actx.einsum( + "rej,rij,ej->ei" if metric_in_matvec else "rei,rij,ej->ei", + ijm_i[xyz_axis], + get_diff_mat( + actx, + out_element_group=out_grp, + in_element_group=in_grp), + vec_i, + arg_names=("inv_jac_t", "ref_stiffT_mat", "vec", ), + tagged=(FirstAxisIsElementsTag(),)) + + # }}} + + return DOFArray( actx, data=tuple( - # r for rst axis - actx.einsum("rej,rij,ej->ei" if metric_in_matvec else "rei,rij,ej->ei", - ijm_i[xyz_axis], - get_diff_mat( - actx, - out_element_group=out_grp, - in_element_group=in_grp), - vec_i, - arg_names=("inv_jac_t", "ref_stiffT_mat", "vec", ), - tagged=(FirstAxisIsElementsTag(),)) - + compute_tensor_product_derivative(actx, in_grp, out_grp, + get_diff_mat, vec_i, ijm_i, + xyz_axis, metric_in_matvec) + if isinstance(in_grp, TensorProductElementGroupBase) + else compute_simplicial_derivative(actx, in_grp, out_grp, + get_diff_mat, vec_i, ijm_i, + xyz_axis, metric_in_matvec) for out_grp, in_grp, vec_i, ijm_i in zip( out_discr.groups, in_discr.groups, vec, inv_jac_mat))) @@ -206,6 +236,7 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. + # {{{ tensor product gradient def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, @@ -575,6 +606,7 @@ def get_ref_derivative_mats(grp): import modepy as mp import numpy.linalg as la + #FIXME: Can be gotten rid of by updating meshmode nodes1d = grp.unit_nodes_1d bases_1d = grp.bases_1d() @@ -773,6 +805,7 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): import modepy as mp import numpy.linalg as la + # FIXME: can be gotten rid of by updating meshmode operators basis_1d = out_grp.bases_1d() nodes_1d = out_grp.unit_nodes_1d From 77d64708736e719a4a65141ce5b8cc258a7cbbdc Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 9 Dec 2023 23:06:03 -0600 Subject: [PATCH 35/66] Chain geometric factors for strong form. Move tp metadata processing into normal actx --- grudge/array_context.py | 109 ++++++++------- grudge/op.py | 68 ++++----- test/test_op.py | 299 ++++++---------------------------------- 3 files changed, 126 insertions(+), 350 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index 219cc6a1f..ab088c457 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -108,8 +108,7 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase): """Inherits from :class:`meshmode.array_context.PyOpenCLArrayContext`. Extends it - to understand :mod:`grudge`-specific transform metadata. (Of which there isn't - any, for now.) + to understand :mod:`grudge`-specific transform metadata. """ def __init__(self, queue: "pyopencl.CommandQueue", allocator: Optional["pyopencl.tools.AllocatorBase"] = None, @@ -124,6 +123,30 @@ def __init__(self, queue: "pyopencl.CommandQueue", super().__init__(queue, allocator, wait_event_queue_length, force_device_scalars) + def transform_loopy_program(self, t_unit): + knl = t_unit.default_entrypoint + + # {{{ process tensor product specific metadata + + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + t_unit = t_unit.with_kernel(knl) + + # }}} + + return super().transform_loopy_program(t_unit) + # }}} @@ -131,8 +154,7 @@ def __init__(self, queue: "pyopencl.CommandQueue", class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase): """Inherits from :class:`meshmode.array_context.PytatoPyOpenCLArrayContext`. - Extends it to understand :mod:`grudge`-specific transform metadata. (Of - which there isn't any, for now.) + Extends it to understand :mod:`grudge`-specific transform metadata. """ def __init__(self, queue, allocator=None, *, @@ -153,6 +175,29 @@ def __init__(self, queue, allocator=None, super().__init__(queue, allocator, compile_trace_callback=compile_trace_callback) + def transform_loopy_program(self, t_unit): + knl = t_unit.default_entrypoint + + # {{{ process tensor product specific metadata + + if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + for arg in knl.args: + if arg.is_output: + arg = arg.copy(dim_tags=( + f"N{len(arg.shape)-1}," + + ",".join(f"N{i}" + for i in range(len(arg.shape)-1)) + )) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + + # }}} + + return super().transform_loopy_program(t_unit) + # }}} @@ -589,73 +634,31 @@ def get_reasonable_array_context_class( # }}} +# {{{ tensor product-specific machinery -# {{{ Tensor product array context - -# {{{ Relevant tags class OutputIsTensorProductDOFArrayOrdered(Tag): """Signify that the strides will not be of order "C" or "F". See :class:`grudge.array_context.TensorProductArrayContext` for more details. - """ - pass -# }}} - -# {{{ Eager TP array context -class TensorProductArrayContext(_PyOpenCLArrayContextBase): - """Specialized array context for use with tensor product elements. The strides for the arrays containing tensor product element data are of the form (slow, fastest, faster, fast). These strides are not "C" or "F" order. Hence, this specialized array context takes care of specifying the particular strides required. """ + pass - def transform_loopy_program(self, t_unit): - if len(t_unit.callables_table) == 1: - knl = t_unit.default_entrypoint - if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): - new_args = [] - for arg in knl.args: - if arg.is_output: - arg = arg.copy(dim_tags=( - f"N{len(arg.shape)-1}," - + ",".join(f"N{i}" - for i in range(len(arg.shape)-1)) - )) - - new_args.append(arg) - - knl = knl.copy(args=new_args) - t_unit = t_unit.with_kernel(knl) +# }}} - return super().transform_loopy_program(t_unit) +# {{{ Eager TP array context +class TensorProductArrayContext(_PyOpenCLArrayContextBase): + """Specialized array context for use with tensor product elements. + """ # }}} # {{{ Lazy tensor product array context class PytatoTensorProductArrayContext(PytatoPyOpenCLArrayContext): def transform_dag(self, dag): return super().transform_dag(dag) - - def transform_loopy_program(self, t_unit): - knl = t_unit.default_entrypoint - - # {{{ adjust strides according to tensor product structure - if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): - new_args = [] - for arg in knl.args: - if arg.is_output: - arg = arg.copy(dim_tags=( - f"N{len(arg.shape)-1}," - + ",".join(f"N{i}" - for i in range(len(arg.shape)-1)) - )) - - new_args.append(arg) - - knl = knl.copy(args=new_args) - # }}} - - return super().transform_loopy_program(t_unit) # }}} # }}} diff --git a/grudge/op.py b/grudge/op.py index 3b2d5d33b..395c778b0 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -81,6 +81,10 @@ DiscretizationFaceAxisTag) from meshmode.discretization.poly_element import TensorProductElementGroupBase +from modepy.tools import ( + reshape_array_for_tensor_product_space as fold, + unreshape_array_for_tensor_product_space as unfold) + from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc from grudge.array_context import ( @@ -244,12 +248,9 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, """Exploits tensor product structure to differentiate each coordinate axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - from modepy.tools import ( - reshape_array_for_tensor_product_space, - unreshape_array_for_tensor_product_space) # reshape u to expose tensor product structure - vec = reshape_array_for_tensor_product_space(grp.space, vec) + vec = fold(grp.space, vec) # apply operators to function data dim = grp.dim @@ -325,12 +326,17 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # strong form case: # x partial: einsum("il,eljk->eijk", D, f) else: + inv_jac_mat_tp = fold(grp.space, inv_jac_mat[0]) grad = make_obj_array([ actx.einsum( - f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", + f"re{'kl'[:i]}i{'mn'[:dim-i-1]}," + + "ij," + + f"e{'kl'[:i]}j{'mn'[:dim-i-1]}->" + + f"e{'kl'[:i]}i{'mn'[:dim-i-1]}", + inv_jac_mat_tp[i], diff_mat, vec, - arg_names=("diff_mat", "vec"), + arg_names=("inv_jac_mat", "diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) for i in range(dim) @@ -338,12 +344,11 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # unreshape grad to apply geometric factors grad = make_obj_array([ - unreshape_array_for_tensor_product_space(grp.space, grad[i]) + unfold(grp.space, grad[i]) for i in range(grp.dim) ]) # apply geometric factors in strong case - from arraycontext.metadata import NameHint if metric_in_matvec: grad = make_obj_array([ actx.einsum( @@ -354,15 +359,6 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, tagged=FirstAxisIsElementsTag()) for i in range(dim) ]) - else: - grad = actx.np.stack([grad[i] for i in range(dim)]) - grad = actx.einsum( - "xrei,xei->xei", - ijm, - grad, - arg_names=("inv_jac_t", "vec"), - tagged=(FirstAxisIsElementsTag(), - NameHint("tp_gradient"),)) return grad @@ -419,12 +415,8 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) """ - from modepy.tools import ( - reshape_array_for_tensor_product_space, - unreshape_array_for_tensor_product_space) - # reshape u to expose tensor product structure - vec = reshape_array_for_tensor_product_space(grp.space, vec) + vec = fold(grp.space, vec) dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) @@ -496,7 +488,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): "tensor product elements. Found dim = {dim}") partials = make_obj_array([ - unreshape_array_for_tensor_product_space(grp.space, partials[i]) + unfold(grp.space, partials[i]) for i in range(dim) ]) @@ -518,32 +510,32 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): # strong form else: + inv_jac_mat_tp = fold(grp.space, inv_jac_mat[0]) partials = make_obj_array([ actx.einsum( - f"ij,e{'kl'[:i]}j{'mn'[:dim-i-1]}->e{'kl'[:i]}i{'mn'[:dim-i-1]}", - diff_mat, - vec[i], - arg_names=("diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(dim) + f"re{'kl'[:i]}i{'mn'[:dim-i-1]}," + + "ij," + + f"e{'kl'[:i]}j{'mn'[:dim-i-1]}->" + + f"e{'kl'[:i]}i{'mn'[:dim-i-1]}", + inv_jac_mat_tp[i], + diff_mat, + vec[i], + arg_names=("inv_jac_mat", "diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for i in range(dim) ]) - # unreshape partials to apply geometric factors - # TODO: chain the einsum above with the einsum below partials = make_obj_array([ - unreshape_array_for_tensor_product_space(grp.space, partials[i]) + unfold(grp.space, partials[i]) for i in range(partials.shape[0]) ]) - # apply geometric factors partials = actx.np.stack([partials[i] for i in range(dim)]) - div = actx.einsum( - "xrei,xei->ei", - ijm, + "xei->ei", partials, - arg_names=("inv_jac_t", "vec"), + arg_names=("inv_jac_t",), tagged=(FirstAxisIsElementsTag(),)) return div diff --git a/test/test_op.py b/test/test_op.py index dfb361b4f..1c07176d1 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -45,6 +45,7 @@ # {{{ gradient +@pytest.mark.parametrize("discr_type", ["simplicial", "tensor product"]) @pytest.mark.parametrize("form", ["strong", "weak"]) @pytest.mark.parametrize("dim", [1, 2, 3]) @pytest.mark.parametrize("order", [2, 3]) @@ -53,7 +54,7 @@ (True, False), (True, True) ]) -def test_gradient(actx_factory, form, dim, order, vectorize, nested, +def test_gradient(actx_factory, discr_type, form, dim, order, vectorize, nested, visualize=False): actx = actx_factory() @@ -61,145 +62,32 @@ def test_gradient(actx_factory, form, dim, order, vectorize, nested, eoc_rec = EOCRecorder() for n in [4, 6, 8]: - mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, b=(1,)*dim, - nelements_per_axis=(n,)*dim) - - dcoll = DiscretizationCollection(actx, mesh, order=order) - - def f(x): - result = dcoll.zeros(actx) + 1 - for i in range(dim-1): - result = result * actx.np.sin(np.pi*x[i]) - result = result * actx.np.cos(np.pi/2*x[dim-1]) - return result - - def grad_f(x): - result = make_obj_array([dcoll.zeros(actx) + 1 for _ in range(dim)]) - for i in range(dim-1): - for j in range(i): - result[i] = result[i] * actx.np.sin(np.pi*x[j]) - result[i] = result[i] * np.pi*actx.np.cos(np.pi*x[i]) - for j in range(i+1, dim-1): - result[i] = result[i] * actx.np.sin(np.pi*x[j]) - result[i] = result[i] * actx.np.cos(np.pi/2*x[dim-1]) - for j in range(dim-1): - result[dim-1] = result[dim-1] * actx.np.sin(np.pi*x[j]) - result[dim-1] = result[dim-1] * (-np.pi/2*actx.np.sin(np.pi/2*x[dim-1])) - return result - - x = actx.thaw(dcoll.nodes()) - - if vectorize: - u = make_obj_array([(i+1)*f(x) for i in range(dim)]) - else: - u = f(x) - - def get_flux(u_tpair): - dd = u_tpair.dd - dd_allfaces = dd.with_dtag("all_faces") - normal = geo.normal(actx, dcoll, dd) - u_avg = u_tpair.avg - if vectorize: - if nested: - flux = make_obj_array([u_avg_i * normal for u_avg_i in u_avg]) - else: - flux = np.outer(u_avg, normal) - else: - flux = u_avg * normal - return op.project(dcoll, dd, dd_allfaces, flux) - - dd_allfaces = DOFDesc("all_faces") - - if form == "strong": - grad_u = ( - op.local_grad(dcoll, u, nested=nested) - # No flux terms because u doesn't have inter-el jumps - ) - elif form == "weak": - grad_u = op.inverse_mass(dcoll, - -op.weak_local_grad(dcoll, u, nested=nested) # pylint: disable=E1130 - + # noqa: W504 - op.face_mass(dcoll, - dd_allfaces, - # Note: no boundary flux terms here because u_ext == u_int == 0 - sum(get_flux(utpair) - for utpair in op.interior_trace_pairs(dcoll, u)) - ) - ) - else: - raise ValueError("Invalid form argument.") - - if vectorize: - expected_grad_u = make_obj_array( - [(i+1)*grad_f(x) for i in range(dim)]) - if not nested: - expected_grad_u = np.stack(expected_grad_u, axis=0) - else: - expected_grad_u = grad_f(x) + if discr_type == "tensor product": + # no reason to test 1D tensor product elements + if dim == 1: + return - if visualize: - from grudge.shortcuts import make_visualizer - vis = make_visualizer(dcoll, vis_order=order if dim == 3 else dim+3) + from meshmode.mesh import TensorProductElementGroup + from meshmode.discretization.poly_element import \ + LegendreGaussLobattoTensorProductGroupFactory as LGL - filename = (f"test_gradient_{form}_{dim}_{order}" - f"{'_vec' if vectorize else ''}{'_nested' if nested else ''}.vtu") - vis.write_vtk_file(filename, [ - ("u", u), - ("grad_u", grad_u), - ("expected_grad_u", expected_grad_u), - ], overwrite=True) - - rel_linf_err = actx.to_numpy( - op.norm(dcoll, grad_u - expected_grad_u, np.inf) - / op.norm(dcoll, expected_grad_u, np.inf)) - eoc_rec.add_data_point(1./n, rel_linf_err) - - print("L^inf error:") - print(eoc_rec) - assert (eoc_rec.order_estimate() >= order - 0.5 - or eoc_rec.max_error() < 1e-11) - - -@pytest.mark.parametrize("form", ["strong", "weak"]) -@pytest.mark.parametrize("dim", [2, 3]) -@pytest.mark.parametrize("order", [2, 3]) -@pytest.mark.parametrize(("vectorize", "nested"), [ - (False, False), - (True, False), - (True, True) - ]) -def test_tensor_product_gradient(form, dim, order, vectorize, - nested, visualize=False): - """A "one-dimensional tensor product element" does not make sense, so the - one-dimensional case is excluded from this test. - """ - - import pyopencl as cl - from grudge.array_context import TensorProductArrayContext - - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - actx = TensorProductArrayContext(queue) - - from pytools.convergence import EOCRecorder - eoc_rec = EOCRecorder() - - from meshmode.mesh import TensorProductElementGroup - from meshmode.discretization.poly_element import \ - LegendreGaussLobattoTensorProductGroupFactory as LGL - for n in [4, 6, 8]: - mesh = mgen.generate_regular_rect_mesh( + mesh = mgen.generate_regular_rect_mesh( a=(-1,)*dim, b=(1,)*dim, nelements_per_axis=(n,)*dim, group_cls=TensorProductElementGroup) - import grudge.dof_desc as dd - dcoll = DiscretizationCollection( + import grudge.dof_desc as dd + dcoll = DiscretizationCollection( actx, mesh, discr_tag_to_group_factory={ dd.DISCR_TAG_BASE: LGL(order)}) + else: + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, b=(1,)*dim, + nelements_per_axis=(n,)*dim) + + dcoll = DiscretizationCollection(actx, mesh, order=order) def f(x): result = dcoll.zeros(actx) + 1 @@ -232,7 +120,7 @@ def grad_f(x): def get_flux(u_tpair): dd = u_tpair.dd dd_allfaces = dd.with_dtag("all_faces") - normal = actx.thaw(dcoll.normal(dd)) + normal = geo.normal(actx, dcoll, dd) u_avg = u_tpair.avg if vectorize: if nested: @@ -299,6 +187,7 @@ def get_flux(u_tpair): # {{{ divergence +@pytest.mark.parametrize("discr_type", ["simplicial", "tensor_product"]) @pytest.mark.parametrize("form", ["strong", "weak"]) @pytest.mark.parametrize("dim", [1, 2, 3]) @pytest.mark.parametrize("order", [2, 3]) @@ -307,7 +196,7 @@ def get_flux(u_tpair): (True, False), (True, True) ]) -def test_divergence(actx_factory, form, dim, order, vectorize, nested, +def test_divergence(actx_factory, discr_type, form, dim, order, vectorize, nested, visualize=False): actx = actx_factory() @@ -315,141 +204,32 @@ def test_divergence(actx_factory, form, dim, order, vectorize, nested, eoc_rec = EOCRecorder() for n in [4, 6, 8]: - mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, b=(1,)*dim, - nelements_per_axis=(n,)*dim) - - dcoll = DiscretizationCollection(actx, mesh, order=order) - - def f(x): - result = make_obj_array([dcoll.zeros(actx) + (i+1) for i in range(dim)]) - for i in range(dim-1): - result = result * actx.np.sin(np.pi*x[i]) - result = result * actx.np.cos(np.pi/2*x[dim-1]) - return result - - def div_f(x): - result = dcoll.zeros(actx) - for i in range(dim-1): - deriv = dcoll.zeros(actx) + (i+1) - for j in range(i): - deriv = deriv * actx.np.sin(np.pi*x[j]) - deriv = deriv * np.pi*actx.np.cos(np.pi*x[i]) - for j in range(i+1, dim-1): - deriv = deriv * actx.np.sin(np.pi*x[j]) - deriv = deriv * actx.np.cos(np.pi/2*x[dim-1]) - result = result + deriv - deriv = dcoll.zeros(actx) + dim - for j in range(dim-1): - deriv = deriv * actx.np.sin(np.pi*x[j]) - deriv = deriv * (-np.pi/2*actx.np.sin(np.pi/2*x[dim-1])) - result = result + deriv - return result - - x = actx.thaw(dcoll.nodes()) - - if vectorize: - u = make_obj_array([(i+1)*f(x) for i in range(dim)]) - if not nested: - u = np.stack(u, axis=0) - else: - u = f(x) - - def get_flux(u_tpair): - dd = u_tpair.dd - dd_allfaces = dd.with_dtag("all_faces") - normal = geo.normal(actx, dcoll, dd) - flux = u_tpair.avg @ normal - return op.project(dcoll, dd, dd_allfaces, flux) - - dd_allfaces = DOFDesc("all_faces") - - if form == "strong": - div_u = ( - op.local_div(dcoll, u) - # No flux terms because u doesn't have inter-el jumps - ) - elif form == "weak": - div_u = op.inverse_mass(dcoll, - -op.weak_local_div(dcoll, u) - + # noqa: W504 - op.face_mass(dcoll, - dd_allfaces, - # Note: no boundary flux terms here because u_ext == u_int == 0 - sum(get_flux(utpair) - for utpair in op.interior_trace_pairs(dcoll, u)) - ) - ) - else: - raise ValueError("Invalid form argument.") - - if vectorize: - expected_div_u = make_obj_array([(i+1)*div_f(x) for i in range(dim)]) - else: - expected_div_u = div_f(x) + if discr_type == "tensor product": + # no reason to test 1D tensor product elements + if dim == 1: + return - if visualize: - from grudge.shortcuts import make_visualizer - vis = make_visualizer(dcoll, vis_order=order if dim == 3 else dim+3) - - filename = (f"test_divergence_{form}_{dim}_{order}" - f"{'_vec' if vectorize else ''}{'_nested' if nested else ''}.vtu") - vis.write_vtk_file(filename, [ - ("u", u), - ("div_u", div_u), - ("expected_div_u", expected_div_u), - ], overwrite=True) + from meshmode.mesh import TensorProductElementGroup + from meshmode.discretization.poly_element import \ + LegendreGaussLobattoTensorProductGroupFactory as LGL - rel_linf_err = actx.to_numpy( - op.norm(dcoll, div_u - expected_div_u, np.inf) - / op.norm(dcoll, expected_div_u, np.inf)) - eoc_rec.add_data_point(1./n, rel_linf_err) - - print("L^inf error:") - print(eoc_rec) - assert (eoc_rec.order_estimate() >= order - 0.5 - or eoc_rec.max_error() < 1e-11) - - -@pytest.mark.parametrize("form", ["strong", "weak"]) -@pytest.mark.parametrize("dim", [2, 3]) -@pytest.mark.parametrize("order", [2, 3]) -@pytest.mark.parametrize(("vectorize", "nested"), [ - (False, False), - (True, False), - (True, True) - ]) -def test_tensor_product_divergence(form, dim, order, vectorize, - nested, visualize=False): - """A "one-dimensional tensor product element" does not make sense, so the - one-dimensional case is excluded from this test. - """ - import pyopencl as cl - from grudge.array_context import TensorProductArrayContext - - ctx = cl.create_some_context() - queue = cl.CommandQueue(ctx) - actx = TensorProductArrayContext(queue) - - from pytools.convergence import EOCRecorder - eoc_rec = EOCRecorder() - - from meshmode.mesh import TensorProductElementGroup - from meshmode.discretization.poly_element import \ - LegendreGaussLobattoTensorProductGroupFactory as LGL - for n in [4, 6, 8]: - mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, - b=(1,)*dim, + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, b=(1,)*dim, nelements_per_axis=(n,)*dim, group_cls=TensorProductElementGroup) - import grudge.dof_desc as dd - dcoll = make_discretization_collection( + import grudge.dof_desc as dd + dcoll = DiscretizationCollection( actx, mesh, discr_tag_to_group_factory={ dd.DISCR_TAG_BASE: LGL(order)}) + else: + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, b=(1,)*dim, + nelements_per_axis=(n,)*dim) + + dcoll = DiscretizationCollection(actx, mesh, order=order) def f(x): result = make_obj_array([dcoll.zeros(actx) + (i+1) for i in range(dim)]) @@ -488,7 +268,7 @@ def div_f(x): def get_flux(u_tpair): dd = u_tpair.dd dd_allfaces = dd.with_dtag("all_faces") - normal = actx.thaw(dcoll.normal(dd)) + normal = geo.normal(actx, dcoll, dd) flux = u_tpair.avg @ normal return op.project(dcoll, dd, dd_allfaces, flux) @@ -539,6 +319,7 @@ def get_flux(u_tpair): print(eoc_rec) assert (eoc_rec.order_estimate() >= order - 0.5 or eoc_rec.max_error() < 1e-11) + # }}} From c3f7543663746b394b59fb979b8e1f41afa1436b Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 10 Dec 2023 21:15:21 -0600 Subject: [PATCH 36/66] Add chained einsums to grad and div. Uncover some new bugs --- grudge/op.py | 256 ++++++++++++++++++++++++++++----------------------- 1 file changed, 140 insertions(+), 116 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 395c778b0..6978506f8 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -190,12 +190,13 @@ def _single_axis_derivative_kernel( # {{{ tensor product single axis derivative - def compute_tensor_product_derivative(actx, in_grp, out_grp, - get_diff_mat, vec, ijm, + # FIXME: actually implement single axis tensor product derivatives + def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, xyz_axis, metric_in_matvec): - return compute_simplicial_derivative(actx, in_grp, out_grp, - get_diff_mat, vec, ijm, - xyz_axis, metric_in_matvec) + + + return compute_simplicial_derivative(actx, grp, grp, get_diff_mat, vec, + ijm, xyz_axis, metric_in_matvec) # }}} @@ -245,12 +246,35 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): - """Exploits tensor product structure to differentiate each coordinate - axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) + """ + Exploits tensor product structure to carry out differentiation with a + differentiation operator containing only 1D information. For example, + in the 2D strong form case, this computes partial derivatives in a + similar manner to + + .. math:: + + \partial_x \mathbf{f}_{ij} = + \sum_{\ell,j}^n \mathbf{J}^e_{i\ell} \mathbf{D}_{i\ell} + \mathbf{f}_{\ell j} + + where $\mathbf{D}$ is a 1D differentiation operator, $\mathbf{f}$ is a + vector of function data, $\mathbf{J}^e$ is the element Jacobian matrix. + The weak form uses a 1D mass operator and a 1D stiffness operator using + the fact that + + .. math:: + + \mathbf{M}^{2D}_{pq,rs} = \int_{\Omega} \phi_p(x) \phi_q(y) + \phi_r(x) \phi_s(y) d\Omega = \int_{\Omega_x} \phi_p(x) + \phi_r(x) dx \int_{\Omega_y} \phi_q(y) \phi_s(y) dy = + \mathbf{M}^{1D} \otimes \mathbf{M}^{1D} + """ # reshape u to expose tensor product structure vec = fold(grp.space, vec) + inv_jac_mat_tp = fold(grp.space, ijm) # apply operators to function data dim = grp.dim @@ -261,35 +285,40 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # f, stiff_1D, mass_1D, mass_1D) if metric_in_matvec: stiff_1D, mass_1D = diff_mat - if dim == 3: weak_x = actx.einsum( - "eltu,pl,qt,ru->epqr", + "rejbd,ejbd,ij,ab,cd->eiac", + inv_jac_mat_tp[0], vec, stiff_1D, mass_1D, mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), + arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s", + "mass_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "eslu,pl,qs,ru->eqpr", + "rebjd,ebjd,ij,ab,cd->eaic", + inv_jac_mat_tp[1], vec, stiff_1D, mass_1D, mass_1D, - arg_names=("vec", "mass_1D_r", "stiff_1D_s", "mass_1D_t"), + arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s", + "mass_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) weak_z = actx.einsum( - "estl,pl,qs,rt->eqrp", + "rebdj,ebdj,ij,ab,cd->eaci", + inv_jac_mat_tp[2], vec, stiff_1D, mass_1D, mass_1D, - arg_names=("vec", "mass_1D_r", "mass_1D_s", "stiff_1D_t"), + arg_names=("inv_jac_mat", "vec", "mass_1D_r", "mass_1D_s", + "stiff_1D_t"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -299,22 +328,41 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, weak_z ]) + # FIXME: causes an error: static maximum not found for PwAff ... + # grad = make_obj_array([ + # actx.einsum( + # f"e{'bd'[:i]}j{'bd'[i:]}," + + # f"e{'bd'[:i]}j{'bd'[i:]}," + + # "ij,ab,cd->" + + # f"e{'ac'[:i]}i{'ac'[i:]}", + # vec, + # stiff_1D, + # mass_1D, + # mass_1D, + # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", "mass_1D"), + # tagged=(FirstAxisIsElementsTag(), + # OutputIsTensorProductDOFArrayOrdered())) + # for i in range(grp.dim) + # ]) + elif dim == 2: weak_x = actx.einsum( - "elt,pl,qt->epq", + "rejb,ejb,ij,ab->eia", + inv_jac_mat_tp[0], vec, stiff_1D, mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s"), + arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "esl,pl,qs->eqp", + "rebj,ebj,ij,ab->eai", + inv_jac_mat_tp[1], vec, stiff_1D, mass_1D, - arg_names=("vec", "mass_1D_r", "stiff_1D_s"), + arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -323,16 +371,15 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, weak_y ]) - # strong form case: + # Carries out, e.g., 3D strong form contraction # x partial: einsum("il,eljk->eijk", D, f) else: - inv_jac_mat_tp = fold(grp.space, inv_jac_mat[0]) grad = make_obj_array([ actx.einsum( - f"re{'kl'[:i]}i{'mn'[:dim-i-1]}," + - "ij," + - f"e{'kl'[:i]}j{'mn'[:dim-i-1]}->" + - f"e{'kl'[:i]}i{'mn'[:dim-i-1]}", + f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}," + + "yz," + + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:dim-i-1]}->" + + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}", inv_jac_mat_tp[i], diff_mat, vec, @@ -342,24 +389,12 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, for i in range(dim) ]) - # unreshape grad to apply geometric factors + # unreshape grad grad = make_obj_array([ unfold(grp.space, grad[i]) for i in range(grp.dim) ]) - # apply geometric factors in strong case - if metric_in_matvec: - grad = make_obj_array([ - actx.einsum( - "rei,ei->ei", - ijm[i], - grad[i], - arg_names=("inv_jac_t", "vec"), - tagged=FirstAxisIsElementsTag()) - for i in range(dim) - ]) - return grad # }}} @@ -417,6 +452,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): # reshape u to expose tensor product structure vec = fold(grp.space, vec) + inv_jac_mat_tp = fold(grp.space, ijm[0]) dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) @@ -424,121 +460,109 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): # weak form if metric_in_matvec: stiff_1D, mass_1D = diff_mat - if dim == 3: weak_x = actx.einsum( - "estu,ps,qt,ru->epqr", - vec[0], - stiff_1D, - mass_1D, - mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s", "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "rejbd,ejbd,ij,ab,cd->eiac", + inv_jac_mat_tp[0], + vec, + stiff_1D, + mass_1D, + mass_1D, + arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s", + "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "estu,ps,qt,ru->epqr", - vec[1], - mass_1D, - stiff_1D, - mass_1D, - arg_names=("vec", "mass_1D_r", "stiff_1D_s", "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "rebjd,ebjd,ij,ab,cd->eaic", + inv_jac_mat_tp[1], + vec, + stiff_1D, + mass_1D, + mass_1D, + arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s", + "mass_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) weak_z = actx.einsum( - "estu,ps,qt,ru->epqr", - vec[2], - mass_1D, - mass_1D, - stiff_1D, - arg_names=("vec", "mass_1D_r", "mass_1D_s", "stiff_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "rebdj,ebdj,ij,ab,cd->eaci", + inv_jac_mat_tp[2], + vec, + stiff_1D, + mass_1D, + mass_1D, + arg_names=("inv_jac_mat", "vec", "mass_1D_r", "mass_1D_s", + "stiff_1D_t"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) partials = make_obj_array([ - weak_x, weak_y, weak_z + weak_x, + weak_y, + weak_z ]) elif dim == 2: weak_x = actx.einsum( - "est,ps,qt->epq", - vec[0], - stiff_1D, - mass_1D, - arg_names=("vec", "stiff_1D_r", "mass_1D_s"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "rejb,ejb,ij,ab->eia", + inv_jac_mat_tp[0], + vec, + stiff_1D, + mass_1D, + arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) weak_y = actx.einsum( - "est,ps,qt->epq", - vec[1], - mass_1D, - stiff_1D, - arg_names=("vec", "mass_1D_r", "stiff_1D_s"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + "rebj,ebj,ij,ab->eai", + inv_jac_mat_tp[1], + vec, + stiff_1D, + mass_1D, + arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) partials = make_obj_array([ - weak_x, weak_y + weak_x, + weak_y ]) else: raise Exception("Dimensions of 2 and 3 are supported by " "tensor product elements. Found dim = {dim}") - partials = make_obj_array([ - unfold(grp.space, partials[i]) - for i in range(dim) - ]) - - partials = actx.np.stack(partials) - - div = make_obj_array([ - actx.einsum("rei,ei->ei", - ijm[i], - partials[i], - arg_names=("inv_jac_t", "vec"), - tagged=(FirstAxisIsElementsTag(),)) - for i in range(dim) - ]) - - ret = 0 - for i in range(dim): - ret += div[i] - return ret - # strong form else: - inv_jac_mat_tp = fold(grp.space, inv_jac_mat[0]) partials = make_obj_array([ actx.einsum( - f"re{'kl'[:i]}i{'mn'[:dim-i-1]}," + - "ij," + - f"e{'kl'[:i]}j{'mn'[:dim-i-1]}->" + - f"e{'kl'[:i]}i{'mn'[:dim-i-1]}", + f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}," + + "yz," + + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:dim-i-1]}->" + + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}", inv_jac_mat_tp[i], diff_mat, - vec[i], + vec, arg_names=("inv_jac_mat", "diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) for i in range(dim) ]) - partials = make_obj_array([ - unfold(grp.space, partials[i]) - for i in range(partials.shape[0]) - ]) + partials = make_obj_array([ + unfold(grp.space, partials[i]) + for i in range(partials.shape[0]) + ]) - partials = actx.np.stack([partials[i] for i in range(dim)]) - div = actx.einsum( - "xei->ei", - partials, - arg_names=("inv_jac_t",), - tagged=(FirstAxisIsElementsTag(),)) + partials = actx.np.stack([partials[i] for i in range(dim)]) + div = actx.einsum( + "xei->ei", + partials, + arg_names=("partials",), + tagged=(FirstAxisIsElementsTag(),)) - return div + return div # }}} From 06f5738beb1891a452c2946bb33408620202a0ff Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 10 Dec 2023 21:20:04 -0600 Subject: [PATCH 37/66] Small adjustment to buggy code --- grudge/op.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/grudge/op.py b/grudge/op.py index 6978506f8..b992ae02a 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -331,7 +331,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # FIXME: causes an error: static maximum not found for PwAff ... # grad = make_obj_array([ # actx.einsum( - # f"e{'bd'[:i]}j{'bd'[i:]}," + + # f"re{'bd'[:i]}j{'bd'[i:]}," + # f"e{'bd'[:i]}j{'bd'[i:]}," + # "ij,ab,cd->" + # f"e{'ac'[:i]}i{'ac'[i:]}", @@ -503,6 +503,22 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): weak_z ]) + # FIXME: causes an error: static maximum not found for PwAff ... + # grad = make_obj_array([ + # actx.einsum( + # f"re{'bd'[:i]}j{'bd'[i:]}," + + # "ij,ab,cd->" + + # f"e{'ac'[:i]}i{'ac'[i:]}", + # vec, + # stiff_1D, + # mass_1D, + # mass_1D, + # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", "mass_1D"), + # tagged=(FirstAxisIsElementsTag(), + # OutputIsTensorProductDOFArrayOrdered())) + # for i in range(grp.dim) + # ]) + elif dim == 2: weak_x = actx.einsum( "rejb,ejb,ij,ab->eia", From e31a8c46e2d6a41b6fcc00e01709dec6f1ff5043 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 10 Dec 2023 21:57:34 -0600 Subject: [PATCH 38/66] Add math explanation and some other small changes --- grudge/op.py | 149 +++++++++++++++++++++++++-------------------------- 1 file changed, 72 insertions(+), 77 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index b992ae02a..77c59ea8f 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -247,37 +247,31 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): """ - Exploits tensor product structure to carry out differentiation with a - differentiation operator containing only 1D information. For example, - in the 2D strong form case, this computes partial derivatives in a - similar manner to + Exploits tensor product structure to reduce complexity. Applies a + differentiation operator containing 1D information to a tensor of DOF + data. For example, in the 2D strong form case, this computes partial + derivatives in a similar manner to - .. math:: + .. math:: - \partial_x \mathbf{f}_{ij} = - \sum_{\ell,j}^n \mathbf{J}^e_{i\ell} \mathbf{D}_{i\ell} - \mathbf{f}_{\ell j} + \partial_x \mathbf{f}_{ij} = \sum_{\ell} \mathbf{J}^e_{ij} + \mathbf{D}_{i\ell} \mathbf{f}_{\ell j} where $\mathbf{D}$ is a 1D differentiation operator, $\mathbf{f}$ is a vector of function data, $\mathbf{J}^e$ is the element Jacobian matrix. - The weak form uses a 1D mass operator and a 1D stiffness operator using - the fact that + The weak form uses a 1D element mass operator and a 1D element stiffness + operator to perform the contraction - .. math:: + .. math:: - \mathbf{M}^{2D}_{pq,rs} = \int_{\Omega} \phi_p(x) \phi_q(y) - \phi_r(x) \phi_s(y) d\Omega = \int_{\Omega_x} \phi_p(x) - \phi_r(x) dx \int_{\Omega_y} \phi_q(y) \phi_s(y) dy = - \mathbf{M}^{1D} \otimes \mathbf{M}^{1D} + \partial_x \mathbf{f}_{ij} = \sum_{\ell,b} \mathbf{J}^e_{\ell b} + \mathbf{f}_{\ell b} \mathbf{S}^e_{i\ell} \mathbf{M}^e_{jb} """ # reshape u to expose tensor product structure vec = fold(grp.space, vec) inv_jac_mat_tp = fold(grp.space, ijm) - - # apply operators to function data - dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) # weak form case: @@ -285,7 +279,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # f, stiff_1D, mass_1D, mass_1D) if metric_in_matvec: stiff_1D, mass_1D = diff_mat - if dim == 3: + if grp.dim == 3: weak_x = actx.einsum( "rejbd,ejbd,ij,ab,cd->eiac", inv_jac_mat_tp[0], @@ -328,24 +322,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, weak_z ]) - # FIXME: causes an error: static maximum not found for PwAff ... - # grad = make_obj_array([ - # actx.einsum( - # f"re{'bd'[:i]}j{'bd'[i:]}," + - # f"e{'bd'[:i]}j{'bd'[i:]}," + - # "ij,ab,cd->" + - # f"e{'ac'[:i]}i{'ac'[i:]}", - # vec, - # stiff_1D, - # mass_1D, - # mass_1D, - # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", "mass_1D"), - # tagged=(FirstAxisIsElementsTag(), - # OutputIsTensorProductDOFArrayOrdered())) - # for i in range(grp.dim) - # ]) - - elif dim == 2: + elif grp.dim == 2: weak_x = actx.einsum( "rejb,ejb,ij,ab->eia", inv_jac_mat_tp[0], @@ -371,22 +348,42 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, weak_y ]) + # FIXME: causes an error: static maximum not found for PwAff ... + # grad = make_obj_array([ + # actx.einsum( + # f"re{'bd'[:i]}j{'bd'[i:]}," + + # f"e{'bd'[:i]}j{'bd'[i:]}," + + # "ij,ab,cd->" + + # f"e{'ac'[:i]}i{'ac'[i:]}", + # vec, + # stiff_1D, + # mass_1D, + # mass_1D, + # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", + # "mass_1D"), + # tagged=(FirstAxisIsElementsTag(), + # OutputIsTensorProductDOFArrayOrdered())) + # for i in range(grp.dim) + # ]) + # Carries out, e.g., 3D strong form contraction # x partial: einsum("il,eljk->eijk", D, f) else: + # FIXME: actually test that all of these dimensions work! (dim 2 and + # 3 work) grad = make_obj_array([ actx.einsum( - f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}," + + f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + "yz," + - f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:dim-i-1]}->" + - f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}", + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", inv_jac_mat_tp[i], diff_mat, vec, arg_names=("inv_jac_mat", "diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - for i in range(dim) + for i in range(grp.dim) ]) # unreshape grad @@ -446,21 +443,20 @@ def _divergence_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec # {{{ tensor product div def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): - """Exploits tensor product structure to differentiate each coordinate - axis using a single differentiation matrix of shape (nnodes1d, nnodes1d) + """ + Exploits tensor product structure to reduce complexity. See + `_gradient_kernel.compute_tensor_product_grad` for more details. """ # reshape u to expose tensor product structure vec = fold(grp.space, vec) inv_jac_mat_tp = fold(grp.space, ijm[0]) - - dim = grp.dim diff_mat = get_diff_mat(actx, grp, grp) # weak form if metric_in_matvec: stiff_1D, mass_1D = diff_mat - if dim == 3: + if grp.dim == 3: weak_x = actx.einsum( "rejbd,ejbd,ij,ab,cd->eiac", inv_jac_mat_tp[0], @@ -503,23 +499,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): weak_z ]) - # FIXME: causes an error: static maximum not found for PwAff ... - # grad = make_obj_array([ - # actx.einsum( - # f"re{'bd'[:i]}j{'bd'[i:]}," + - # "ij,ab,cd->" + - # f"e{'ac'[:i]}i{'ac'[i:]}", - # vec, - # stiff_1D, - # mass_1D, - # mass_1D, - # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", "mass_1D"), - # tagged=(FirstAxisIsElementsTag(), - # OutputIsTensorProductDOFArrayOrdered())) - # for i in range(grp.dim) - # ]) - - elif dim == 2: + elif grp.dim == 2: weak_x = actx.einsum( "rejb,ejb,ij,ab->eia", inv_jac_mat_tp[0], @@ -545,38 +525,53 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): weak_y ]) - else: - raise Exception("Dimensions of 2 and 3 are supported by " - "tensor product elements. Found dim = {dim}") + # FIXME: causes an error: static maximum not found for PwAff ... + # partials = make_obj_array([ + # actx.einsum( + # f"re{'bd'[:i]}j{'bd'[i:]}," + + # "ij,ab,cd->" + + # f"e{'ac'[:i]}i{'ac'[i:]}", + # vec, + # stiff_1D, + # mass_1D, + # mass_1D, + # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", + # "mass_1D"), + # tagged=(FirstAxisIsElementsTag(), + # OutputIsTensorProductDOFArrayOrdered())) + # for i in range(grp.dim) + # ]) # strong form else: + # FIXME: actually test that all of these dimensions work! (dim 2 and + # 3 work) partials = make_obj_array([ actx.einsum( - f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}," + + f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + "yz," + - f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:dim-i-1]}->" + - f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:dim-i-1]}", + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", inv_jac_mat_tp[i], diff_mat, vec, arg_names=("inv_jac_mat", "diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(dim) + OutputIsTensorProductDOFArrayOrdered())) + for i in range(grp.dim) ]) partials = make_obj_array([ unfold(grp.space, partials[i]) - for i in range(partials.shape[0]) + for i in range(grp.dim) ]) - partials = actx.np.stack([partials[i] for i in range(dim)]) + partials = actx.np.stack([partials[i] for i in range(grp.dim)]) div = actx.einsum( - "xei->ei", - partials, - arg_names=("partials",), - tagged=(FirstAxisIsElementsTag(),)) + "xei->ei", + partials, + arg_names=("partials",), + tagged=(FirstAxisIsElementsTag(),)) return div From e293bd015343946c09ff2d0238f657190a57287e Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 11 Dec 2023 00:53:12 -0600 Subject: [PATCH 39/66] Update comment --- grudge/op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grudge/op.py b/grudge/op.py index 77c59ea8f..61bb4c4bd 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -246,6 +246,8 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): + # TODO: add note about inverse mass simplification, point to + # op.inverse_mass (assuming this is where the explanation will live) """ Exploits tensor product structure to reduce complexity. Applies a differentiation operator containing 1D information to a tensor of DOF @@ -266,7 +268,6 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, \partial_x \mathbf{f}_{ij} = \sum_{\ell,b} \mathbf{J}^e_{\ell b} \mathbf{f}_{\ell b} \mathbf{S}^e_{i\ell} \mathbf{M}^e_{jb} - """ # reshape u to expose tensor product structure From 0792c0b979c23a98292ffa12cae86b4a0f603686 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 11 Dec 2023 10:41:51 -0600 Subject: [PATCH 40/66] Update requirements.txt to accurately reflect requirements --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index f56f10888..86e4a9237 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,11 +10,11 @@ git+https://github.com/inducer/leap.git#egg=leap git+https://github.com/inducer/meshpy.git#egg=meshpy git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/arraycontext.git#egg=arraycontext -git+https://github.com/a-alveyblanc/meshmode.git@tensor-product-1d-nodes-and-1d-basis#egg=meshmode +git+https://github.com/inducer/meshmode.git#egg=meshmode git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/pymetis.git#egg=pymetis git+https://github.com/illinois-ceesd/logpyle.git#egg=logpyle -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/a-alveyblanc/pytato.git@implement-f-ordered-reshapes # for test_wave_dt_estimate sympy From 58c39bdad209ef7522236c5e308008bed7d227e7 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 11 Dec 2023 11:22:33 -0600 Subject: [PATCH 41/66] Run test_inverse_metric for tensor product meshes --- test/test_metrics.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/test/test_metrics.py b/test/test_metrics.py index 586b093a5..b883a5776 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -23,6 +23,8 @@ THE SOFTWARE. """ +from meshmode.discretization.poly_element import LegendreGaussLobattoTensorProductGroupFactory +from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup import numpy as np from grudge.array_context import ( @@ -51,12 +53,18 @@ @pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("nonaffine", [False, True]) @pytest.mark.parametrize("use_quad", [False, True]) -def test_inverse_metric(actx_factory, dim, nonaffine, use_quad): +@pytest.mark.parametrize("group_cls", [ + SimplexElementGroup, + TensorProductElementGroup +]) +def test_inverse_metric(actx_factory, dim, nonaffine, use_quad, group_cls): actx = actx_factory() order = 3 - mesh = mgen.generate_regular_rect_mesh(a=(-0.5,)*dim, b=(0.5,)*dim, - nelements_per_axis=(6,)*dim, order=order) + mesh = mgen.generate_regular_rect_mesh( + a=(-0.5,)*dim, b=(0.5,)*dim, + nelements_per_axis=(6,)*dim, order=order, + group_cls=group_cls) if nonaffine: def m(x): @@ -79,13 +87,21 @@ def m(x): QuadratureSimplexGroupFactory, \ default_simplex_group_factory - dcoll = DiscretizationCollection( - actx, mesh, - discr_tag_to_group_factory={ + if group_cls is SimplexElementGroup: + discr_tag_to_group_factory = { DISCR_TAG_BASE: default_simplex_group_factory(base_dim=dim, order=order), DISCR_TAG_QUAD: QuadratureSimplexGroupFactory(2*order + 1), } - ) + elif group_cls is TensorProductElementGroup: + discr_tag_to_group_factory = { + DISCR_TAG_BASE: LegendreGaussLobattoTensorProductGroupFactory(order=order), + DISCR_TAG_QUAD: LegendreGaussLobattoTensorProductGroupFactory(order=3*order), + } + else: + raise AssertionError() + + dcoll = DiscretizationCollection( + actx, mesh, discr_tag_to_group_factory=discr_tag_to_group_factory) from grudge.geometry import \ forward_metric_derivative_mat, inverse_metric_derivative_mat From b8b239c9181e367317dcdc0603c8e76d9872a974 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 11 Dec 2023 11:27:41 -0600 Subject: [PATCH 42/66] test_inverse_metric: Rotate mesh in non-affine case --- test/test_metrics.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_metrics.py b/test/test_metrics.py index b883a5776..7ab4b5ffe 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -25,6 +25,7 @@ from meshmode.discretization.poly_element import LegendreGaussLobattoTensorProductGroupFactory from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup +from meshmode.mesh.processing import affine_map import numpy as np from grudge.array_context import ( @@ -81,6 +82,15 @@ def m(x): from meshmode.mesh.processing import map_mesh mesh = map_mesh(mesh, m) + else: + alpha = 0.3 + rot_mat = np.array([ + [np.cos(alpha), np.sin(alpha), 0], + [-np.sin(alpha), np.cos(alpha), 0], + [0, 0, 1], + ])[:dim, :dim] + + mesh = affine_map(mesh, A=rot_mat) from grudge.dof_desc import as_dofdesc, DISCR_TAG_BASE, DISCR_TAG_QUAD from meshmode.discretization.poly_element import \ From be8a26b9c99f42b4107a0f96a73f15a86720bbf2 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 11 Dec 2023 12:11:02 -0600 Subject: [PATCH 43/66] Rotate mesh in test_gradient --- test/test_op.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_op.py b/test/test_op.py index 1c07176d1..139c73866 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -21,6 +21,7 @@ """ +from meshmode.mesh.processing import affine_map import numpy as np import meshmode.mesh.generation as mgen @@ -96,6 +97,15 @@ def f(x): result = result * actx.np.cos(np.pi/2*x[dim-1]) return result + alpha = 0.3 + rot_mat = np.array([ + [np.cos(alpha), np.sin(alpha), 0], + [-np.sin(alpha), np.cos(alpha), 0], + [0, 0, 1], + ])[:dim, :dim] + + mesh = affine_map(mesh, A=rot_mat) + def grad_f(x): result = make_obj_array([dcoll.zeros(actx) + 1 for _ in range(dim)]) for i in range(dim-1): From 3f8a886f3aef76b9d0be9f5d626f6cf4e1b90841 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 11 Dec 2023 23:19:37 -0600 Subject: [PATCH 44/66] Add TP divergence thm test. Make diff. op. einsums more general --- grudge/op.py | 220 ++++++++++---------------------------------- test/test_grudge.py | 69 ++++++++++---- test/test_op.py | 9 +- 3 files changed, 111 insertions(+), 187 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 61bb4c4bd..5f2377810 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -74,7 +74,7 @@ from functools import partial -from meshmode.dof_array import DOFArray +from meshmode.dof_array import DOFArray, warn from meshmode.transform_metadata import (FirstAxisIsElementsTag, DiscretizationDOFAxisTag, DiscretizationElementAxisTag, @@ -270,6 +270,15 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, \mathbf{f}_{\ell b} \mathbf{S}^e_{i\ell} \mathbf{M}^e_{jb} """ + + if grp.dim > 3 and metric_in_matvec: + warn('Efficient tensor product weak ' + 'differentiation operators only ' + 'implemented for dimension 2 and 3. ' + 'Defaulting to inefficient version.') + return compute_simplicial_grad(actx, grp, grp, diff_mat, vec, ijm, + metric_in_matvec) + # reshape u to expose tensor product structure vec = fold(grp.space, vec) inv_jac_mat_tp = fold(grp.space, ijm) @@ -278,100 +287,32 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # weak form case: # 3D weak_x: einsum("estu,ps,qt,ru->epqr", # f, stiff_1D, mass_1D, mass_1D) + # TODO:? make this more general, maybe offload to a function that + # generates argnames and einsum specs if metric_in_matvec: stiff_1D, mass_1D = diff_mat - if grp.dim == 3: - weak_x = actx.einsum( - "rejbd,ejbd,ij,ab,cd->eiac", - inv_jac_mat_tp[0], - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s", - "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - weak_y = actx.einsum( - "rebjd,ebjd,ij,ab,cd->eaic", - inv_jac_mat_tp[1], - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s", - "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - weak_z = actx.einsum( - "rebdj,ebdj,ij,ab,cd->eaci", - inv_jac_mat_tp[2], - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "mass_1D_r", "mass_1D_s", - "stiff_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - grad = make_obj_array([ - weak_x, - weak_y, - weak_z - ]) - - elif grp.dim == 2: - weak_x = actx.einsum( - "rejb,ejb,ij,ab->eia", - inv_jac_mat_tp[0], - vec, - stiff_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - weak_y = actx.einsum( - "rebj,ebj,ij,ab->eai", - inv_jac_mat_tp[1], + grad = make_obj_array([ + actx.einsum( + f"re{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + "ij," + + ("ab,cd" if grp.dim == 3 else "ab") + + "->" + f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", + inv_jac_mat_tp[i], vec, stiff_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s"), + *(mass_1D,)*(grp.dim-1), + arg_names=("inv_jac_mat", "vec", "stiff_1D", + *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - - grad = make_obj_array([ - weak_x, - weak_y - ]) - - # FIXME: causes an error: static maximum not found for PwAff ... - # grad = make_obj_array([ - # actx.einsum( - # f"re{'bd'[:i]}j{'bd'[i:]}," + - # f"e{'bd'[:i]}j{'bd'[i:]}," + - # "ij,ab,cd->" + - # f"e{'ac'[:i]}i{'ac'[i:]}", - # vec, - # stiff_1D, - # mass_1D, - # mass_1D, - # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", - # "mass_1D"), - # tagged=(FirstAxisIsElementsTag(), - # OutputIsTensorProductDOFArrayOrdered())) - # for i in range(grp.dim) - # ]) + for i in range(grp.dim) + ]) # Carries out, e.g., 3D strong form contraction # x partial: einsum("il,eljk->eijk", D, f) else: - # FIXME: actually test that all of these dimensions work! (dim 2 and - # 3 work) grad = make_obj_array([ actx.einsum( f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + @@ -449,104 +390,43 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): `_gradient_kernel.compute_tensor_product_grad` for more details. """ + if grp.dim > 3 and metric_in_matvec: + warn('Efficient tensor product weak ' + 'differentiation operators only ' + 'implemented for dimension 2 and 3. ' + 'Defaulting to inefficient version.') + return compute_simplicial_div(actx, grp, grp, diff_mat, vec, ijm, + metric_in_matvec) + # reshape u to expose tensor product structure vec = fold(grp.space, vec) - inv_jac_mat_tp = fold(grp.space, ijm[0]) + inv_jac_mat_tp = fold(grp.space, ijm) diff_mat = get_diff_mat(actx, grp, grp) # weak form if metric_in_matvec: stiff_1D, mass_1D = diff_mat - if grp.dim == 3: - weak_x = actx.einsum( - "rejbd,ejbd,ij,ab,cd->eiac", - inv_jac_mat_tp[0], - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s", - "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - weak_y = actx.einsum( - "rebjd,ebjd,ij,ab,cd->eaic", - inv_jac_mat_tp[1], - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s", - "mass_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - weak_z = actx.einsum( - "rebdj,ebdj,ij,ab,cd->eaci", - inv_jac_mat_tp[2], - vec, - stiff_1D, - mass_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "mass_1D_r", "mass_1D_s", - "stiff_1D_t"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - partials = make_obj_array([ - weak_x, - weak_y, - weak_z - ]) - - elif grp.dim == 2: - weak_x = actx.einsum( - "rejb,ejb,ij,ab->eia", - inv_jac_mat_tp[0], - vec, - stiff_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "stiff_1D_r", "mass_1D_s"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - - weak_y = actx.einsum( - "rebj,ebj,ij,ab->eai", - inv_jac_mat_tp[1], + partials = make_obj_array([ + actx.einsum( + f"re{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + "ij," + + ("ab,cd" if grp.dim == 3 else "ab") + + "->" + f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", + inv_jac_mat_tp[i], vec, stiff_1D, - mass_1D, - arg_names=("inv_jac_mat", "vec", "mass_1D_r", "stiff_1D_s"), + *(mass_1D,)*(grp.dim-1), + arg_names=("inv_jac_mat", "vec", "stiff_1D", + *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) - - partials = make_obj_array([ - weak_x, - weak_y - ]) - - # FIXME: causes an error: static maximum not found for PwAff ... - # partials = make_obj_array([ - # actx.einsum( - # f"re{'bd'[:i]}j{'bd'[i:]}," + - # "ij,ab,cd->" + - # f"e{'ac'[:i]}i{'ac'[i:]}", - # vec, - # stiff_1D, - # mass_1D, - # mass_1D, - # arg_names=("inv_jac_mat", "vec", "stiff_1D", "mass_1D", - # "mass_1D"), - # tagged=(FirstAxisIsElementsTag(), - # OutputIsTensorProductDOFArrayOrdered())) - # for i in range(grp.dim) - # ]) + for i in range(grp.dim) + ]) # strong form else: - # FIXME: actually test that all of these dimensions work! (dim 2 and - # 3 work) partials = make_obj_array([ actx.einsum( f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + @@ -555,7 +435,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", inv_jac_mat_tp[i], diff_mat, - vec, + vec[i], arg_names=("inv_jac_mat", "diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) diff --git a/test/test_grudge.py b/test/test_grudge.py index 547ce8a2c..fd225cca0 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -34,6 +34,7 @@ from meshmode import _acf # noqa: F401 from meshmode.dof_array import flat_norm +from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup import meshmode.mesh.generation as mgen from pytools.obj_array import flat_obj_array @@ -428,31 +429,66 @@ def df(x, axis): # {{{ divergence theorem -def test_2d_gauss_theorem(actx_factory): +@pytest.mark.parametrize("group_cls", + [SimplexElementGroup, TensorProductElementGroup]) +def test_2d_gauss_theorem(actx_factory, group_cls): """Verify Gauss's theorem explicitly on a mesh""" - pytest.importorskip("meshpy") + actx = actx_factory() - from meshpy.geometry import make_circle, GeometryBuilder - from meshpy.triangle import MeshInfo, build + if group_cls is SimplexElementGroup: + pytest.importorskip("meshpy") - geob = GeometryBuilder() - geob.add_geometry(*make_circle(1)) - mesh_info = MeshInfo() - geob.set(mesh_info) + from meshpy.geometry import make_circle, GeometryBuilder + from meshpy.triangle import MeshInfo, build - mesh_info = build(mesh_info) + geob = GeometryBuilder() + geob.add_geometry(*make_circle(1)) + mesh_info = MeshInfo() + geob.set(mesh_info) - from meshmode.mesh.io import from_meshpy - from meshmode.mesh import BTAG_ALL + mesh_info = build(mesh_info) - mesh = from_meshpy(mesh_info, order=1) + from meshmode.mesh.io import from_meshpy + from meshmode.mesh import BTAG_ALL - actx = actx_factory() + mesh = from_meshpy(mesh_info, order=1) - dcoll = DiscretizationCollection(actx, mesh, order=2) - volm_disc = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - x_volm = actx.thaw(volm_disc.nodes()) + dcoll = DiscretizationCollection(actx, mesh, order=2) + volm_disc = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + x_volm = actx.thaw(volm_disc.nodes()) + + elif group_cls is TensorProductElementGroup: + from meshmode.mesh.generation import generate_regular_rect_mesh + + dim = 2 + mesh = generate_regular_rect_mesh( + (-1,)*dim, (1,)*dim, nelements_per_axis=(2,)*dim, + group_cls=TensorProductElementGroup) + + alpha = 0.3 + rot_mat = np.array([ + [np.cos(alpha), np.sin(alpha)], + [-np.sin(alpha), np.cos(alpha)] + ]) + + from meshmode.mesh.processing import affine_map + mesh = affine_map(mesh, A=rot_mat) + + from meshmode.discretization.poly_element import \ + LegendreGaussLobattoTensorProductGroupFactory as LGL + dcoll = DiscretizationCollection( + actx, mesh, discr_tag_to_group_factory={ + dof_desc.DISCR_TAG_BASE: LGL(order=2) + } + ) + + volm_disc = dcoll.discr_from_dd(dof_desc.DD_VOLUME) + x_volm = actx.thaw(volm_disc.nodes()) + + else: + raise AssertionError('group_cls must be SimplexElementGroup or ' + f'TensorProductElementGroup. Found {group_cls}') def f(x): return flat_obj_array( @@ -463,6 +499,7 @@ def f(x): f_volm = f(x_volm) int_1 = op.integral(dcoll, "vol", op.local_div(dcoll, f_volm)) + from grudge.dof_desc import BTAG_ALL prj_f = op.project(dcoll, "vol", BTAG_ALL, f_volm) normal = geo.normal(actx, dcoll, BTAG_ALL) int_2 = op.integral(dcoll, BTAG_ALL, prj_f.dot(normal)) diff --git a/test/test_op.py b/test/test_op.py index 139c73866..767dd55e7 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -33,7 +33,6 @@ import pytest -from grudge.discretization import make_discretization_collection from grudge.array_context import PytestPyOpenCLArrayContextFactory from arraycontext import pytest_generate_tests_for_array_contexts pytest_generate_tests = pytest_generate_tests_for_array_contexts( @@ -241,6 +240,14 @@ def test_divergence(actx_factory, discr_type, form, dim, order, vectorize, neste dcoll = DiscretizationCollection(actx, mesh, order=order) + alpha = 0.3 + rot_mat = np.array([ + [np.cos(alpha), np.sin(alpha), 0], + [-np.sin(alpha), np.cos(alpha), 0], + [0, 0, 1], + ])[:dim, :dim] + + mesh = affine_map(mesh, A=rot_mat) def f(x): result = make_obj_array([dcoll.zeros(actx) + (i+1) for i in range(dim)]) for i in range(dim-1): From e1365389db7e61664752fb60746bf4675274cc9a Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 16 Dec 2023 14:20:31 -0600 Subject: [PATCH 45/66] Fixed tensor product operators, needs major refactor but it works for now --- grudge/geometry/metrics.py | 1 + grudge/op.py | 95 +++++++++++++++++++++++--------------- test/test_op.py | 84 ++++++++++++++++++--------------- 3 files changed, 106 insertions(+), 74 deletions(-) diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index c77e2de16..c33b0815e 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -527,6 +527,7 @@ def inverse_surface_metric_derivative_mat( @memoize_in(dcoll, (inverse_surface_metric_derivative_mat, dd, times_area_element, _use_geoderiv_connection)) def _inv_surf_metric_deriv(): + if times_area_element: multiplier = area_element(actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection) diff --git a/grudge/op.py b/grudge/op.py index 5f2377810..e9a422f70 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -281,7 +281,6 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # reshape u to expose tensor product structure vec = fold(grp.space, vec) - inv_jac_mat_tp = fold(grp.space, ijm) diff_mat = get_diff_mat(actx, grp, grp) # weak form case: @@ -293,17 +292,15 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, stiff_1D, mass_1D = diff_mat grad = make_obj_array([ actx.einsum( - f"re{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + "ij," + ("ab,cd" if grp.dim == 3 else "ab") + "->" f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - inv_jac_mat_tp[i], vec, stiff_1D, *(mass_1D,)*(grp.dim-1), - arg_names=("inv_jac_mat", "vec", "stiff_1D", + arg_names=("vec", "stiff_1D", *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -315,25 +312,35 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, else: grad = make_obj_array([ actx.einsum( - f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + "yz," + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", - inv_jac_mat_tp[i], diff_mat, vec, - arg_names=("inv_jac_mat", "diff_mat", "vec"), + arg_names=("diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) for i in range(grp.dim) ]) - # unreshape grad - grad = make_obj_array([ - unfold(grp.space, grad[i]) - for i in range(grp.dim) + # {{{ unreshape grad and apply geometric factors + + # TODO: Chain einsums together with geometric factors + grad = actx.np.stack([ + unfold(grp.space, grad[rst_axis]) + for rst_axis in range(grp.dim) ]) + grad = actx.einsum( + "xrej,rej->xej", + ijm, + grad, + arg_names=("inv_jac_mat", "grad"), + tagged=(FirstAxisIsElementsTag(),) + ) + + # }}} + return grad # }}} @@ -399,26 +406,26 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): metric_in_matvec) # reshape u to expose tensor product structure - vec = fold(grp.space, vec) - inv_jac_mat_tp = fold(grp.space, ijm) diff_mat = get_diff_mat(actx, grp, grp) + vec = make_obj_array([ + fold(grp.space, vec[xyz_axis]) + for xyz_axis in range(grp.dim) + ]) # weak form if metric_in_matvec: stiff_1D, mass_1D = diff_mat partials = make_obj_array([ actx.einsum( - f"re{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + "ij," + ("ab,cd" if grp.dim == 3 else "ab") + "->" f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - inv_jac_mat_tp[i], - vec, + vec[i], stiff_1D, *(mass_1D,)*(grp.dim-1), - arg_names=("inv_jac_mat", "vec", "stiff_1D", + arg_names=("vec", "stiff_1D", *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) @@ -428,31 +435,47 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): # strong form else: partials = make_obj_array([ - actx.einsum( - f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + - "yz," + - f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + - f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", - inv_jac_mat_tp[i], - diff_mat, - vec[i], - arg_names=("inv_jac_mat", "diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(grp.dim) + make_obj_array([ + actx.einsum( + "yz," + + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", + diff_mat, + vec[func_axis], + arg_names=("diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for i in range(grp.dim) + ]) + for func_axis in range(grp.dim) ]) - partials = make_obj_array([ - unfold(grp.space, partials[i]) - for i in range(grp.dim) + # {{{ unreshape, apply geometric factors, and sum over partials + + # TODO: Chain einsums together with geometric factors + partials = actx.np.stack([ + unfold(grp.space, partials[xyz_axis][rst_axis]) + for xyz_axis in range(grp.dim) + for rst_axis in range(grp.dim) ]) - partials = actx.np.stack([partials[i] for i in range(grp.dim)]) + try: + partials = partials.reshape( + grp.dim, grp.dim, partials.shape[1], partials.shape[2]) + except IndexError: + partials = partials.reshape( + grp.dim, grp.dim, partials.shape[1] + ) + div = actx.einsum( - "xei->ei", + "xrej,xrej->ej", + ijm, partials, - arg_names=("partials",), - tagged=(FirstAxisIsElementsTag(),)) + arg_names=("inv_jac_mat", "partials",), + tagged=(FirstAxisIsElementsTag(),) + ) + + # }}} return div diff --git a/test/test_op.py b/test/test_op.py index 767dd55e7..0e569e541 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -31,6 +31,8 @@ from grudge import op, geometry as geo, DiscretizationCollection from grudge.dof_desc import DOFDesc +from meshmode.mesh import SimplexElementGroup, TensorProductElementGroup + import pytest from grudge.array_context import PytestPyOpenCLArrayContextFactory @@ -45,7 +47,10 @@ # {{{ gradient -@pytest.mark.parametrize("discr_type", ["simplicial", "tensor product"]) +@pytest.mark.parametrize("group_cls", [ + # SimplexElementGroup, + TensorProductElementGroup +]) @pytest.mark.parametrize("form", ["strong", "weak"]) @pytest.mark.parametrize("dim", [1, 2, 3]) @pytest.mark.parametrize("order", [2, 3]) @@ -54,47 +59,40 @@ (True, False), (True, True) ]) -def test_gradient(actx_factory, discr_type, form, dim, order, vectorize, nested, - visualize=False): +def test_gradient(actx_factory, form, dim, order, vectorize, nested, + group_cls, visualize=False): actx = actx_factory() from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() for n in [4, 6, 8]: - if discr_type == "tensor product": + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, b=(1,)*dim, + nelements_per_axis=(n,)*dim, + group_cls=group_cls) + + if group_cls is TensorProductElementGroup: # no reason to test 1D tensor product elements if dim == 1: return - from meshmode.mesh import TensorProductElementGroup + import grudge.dof_desc as dd from meshmode.discretization.poly_element import \ LegendreGaussLobattoTensorProductGroupFactory as LGL - mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, b=(1,)*dim, - nelements_per_axis=(n,)*dim, - group_cls=TensorProductElementGroup) - - import grudge.dof_desc as dd dcoll = DiscretizationCollection( actx, mesh, discr_tag_to_group_factory={ dd.DISCR_TAG_BASE: LGL(order)}) - else: - mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, b=(1,)*dim, - nelements_per_axis=(n,)*dim) + elif group_cls is SimplexElementGroup: dcoll = DiscretizationCollection(actx, mesh, order=order) - def f(x): - result = dcoll.zeros(actx) + 1 - for i in range(dim-1): - result = result * actx.np.sin(np.pi*x[i]) - result = result * actx.np.cos(np.pi/2*x[dim-1]) - return result + else: + raise AssertionError('Expecting TensorProductElementGroup or ' + f'SimplexElementGroup. Found {group_cls}') alpha = 0.3 rot_mat = np.array([ @@ -105,6 +103,13 @@ def f(x): mesh = affine_map(mesh, A=rot_mat) + def f(x): + result = dcoll.zeros(actx) + 1 + for i in range(dim-1): + result = result * actx.np.sin(np.pi*x[i]) + result = result * actx.np.cos(np.pi/2*x[dim-1]) + return result + def grad_f(x): result = make_obj_array([dcoll.zeros(actx) + 1 for _ in range(dim)]) for i in range(dim-1): @@ -196,50 +201,53 @@ def get_flux(u_tpair): # {{{ divergence -@pytest.mark.parametrize("discr_type", ["simplicial", "tensor_product"]) +@pytest.mark.parametrize("group_cls", [ + # SimplexElementGroup, + TensorProductElementGroup +]) @pytest.mark.parametrize("form", ["strong", "weak"]) -@pytest.mark.parametrize("dim", [1, 2, 3]) +@pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ (False, False), (True, False), (True, True) - ]) -def test_divergence(actx_factory, discr_type, form, dim, order, vectorize, nested, - visualize=False): +]) +def test_divergence(actx_factory, form, dim, order, vectorize, nested, + group_cls, visualize=False): actx = actx_factory() from pytools.convergence import EOCRecorder eoc_rec = EOCRecorder() for n in [4, 6, 8]: - if discr_type == "tensor product": + mesh = mgen.generate_regular_rect_mesh( + a=(-1,)*dim, b=(1,)*dim, + nelements_per_axis=(n,)*dim, + group_cls=group_cls) + + if group_cls is TensorProductElementGroup: # no reason to test 1D tensor product elements if dim == 1: return - from meshmode.mesh import TensorProductElementGroup + import grudge.dof_desc as dd from meshmode.discretization.poly_element import \ LegendreGaussLobattoTensorProductGroupFactory as LGL - mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, b=(1,)*dim, - nelements_per_axis=(n,)*dim, - group_cls=TensorProductElementGroup) - - import grudge.dof_desc as dd dcoll = DiscretizationCollection( actx, mesh, discr_tag_to_group_factory={ dd.DISCR_TAG_BASE: LGL(order)}) - else: - mesh = mgen.generate_regular_rect_mesh( - a=(-1,)*dim, b=(1,)*dim, - nelements_per_axis=(n,)*dim) + elif group_cls is SimplexElementGroup: dcoll = DiscretizationCollection(actx, mesh, order=order) + else: + raise AssertionError('Expecting TensorProductElementGroup or ' + f'SimplexElementGroup. Found {group_cls}') + alpha = 0.3 rot_mat = np.array([ [np.cos(alpha), np.sin(alpha), 0], From 82b2c093a86394b65cb71f401ede0440789863e3 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 16 Dec 2023 14:43:30 -0600 Subject: [PATCH 46/66] Update weak form div --- grudge/op.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index e9a422f70..b23c1a219 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -416,20 +416,23 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): if metric_in_matvec: stiff_1D, mass_1D = diff_mat partials = make_obj_array([ - actx.einsum( - f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + - "ij," + - ("ab,cd" if grp.dim == 3 else "ab") + - "->" - f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - vec[i], - stiff_1D, - *(mass_1D,)*(grp.dim-1), - arg_names=("vec", "stiff_1D", - *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(grp.dim) + make_obj_array([ + actx.einsum( + f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + "ij," + + ("ab,cd" if grp.dim == 3 else "ab") + + "->" + f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", + vec[func_axis], + stiff_1D, + *(mass_1D,)*(grp.dim-1), + arg_names=("vec", "stiff_1D", + *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered())) + for i in range(grp.dim) + ]) + for func_axis in range(grp.dim) ]) # strong form From 2bb9228aeb23ec69da3c85d887c71b3dc742e50a Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 4 Jan 2024 10:54:32 -0600 Subject: [PATCH 47/66] Update pulse example. Update grudge tensor product gradient --- .../tensor-product-examples/acoustic_pulse.py | 18 ++-- grudge/op.py | 88 ++++++++----------- test/test_op.py | 2 +- 3 files changed, 51 insertions(+), 57 deletions(-) diff --git a/examples/tensor-product-examples/acoustic_pulse.py b/examples/tensor-product-examples/acoustic_pulse.py index 13c2194cf..8909a5109 100644 --- a/examples/tensor-product-examples/acoustic_pulse.py +++ b/examples/tensor-product-examples/acoustic_pulse.py @@ -111,7 +111,7 @@ def acoustic_pulse_condition(x_vec, t=0): def run_acoustic_pulse(actx, order=3, final_time=1, - resolution=4, + resolution=16, overintegration=False, visualize=False): @@ -122,7 +122,7 @@ def run_acoustic_pulse(actx, from meshmode.mesh.generation import generate_regular_rect_mesh - dim = 3 + dim = 2 box_ll = -0.5 box_ur = 0.5 mesh = generate_regular_rect_mesh( @@ -131,6 +131,15 @@ def run_acoustic_pulse(actx, nelements_per_axis=(resolution,)*dim, group_cls=TensorProductElementGroup) + if rotate_mesh: + from meshmode.mesh.processing import affine_map + alpha = .3 + rot_mat = np.array([ + [np.cos(alpha), np.sin(alpha)], + [-np.sin(alpha), np.cos(alpha)] + ]) + mesh = affine_map(mesh, A=rot_mat) + from grudge import DiscretizationCollection from grudge.dof_desc import DISCR_TAG_BASE, DISCR_TAG_QUAD from meshmode.discretization.poly_element import \ @@ -222,11 +231,10 @@ def main(ctx_factory, order=3, final_time=1, resolution=16, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), ) else: - from grudge.array_context import TensorProductArrayContext - actx = TensorProductArrayContext( + actx = PyOpenCLArrayContext( queue, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), - force_device_scalars=True, + force_device_scalars=False ) run_acoustic_pulse( diff --git a/grudge/op.py b/grudge/op.py index b23c1a219..e7e48db8c 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -246,8 +246,6 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): - # TODO: add note about inverse mass simplification, point to - # op.inverse_mass (assuming this is where the explanation will live) """ Exploits tensor product structure to reduce complexity. Applies a differentiation operator containing 1D information to a tensor of DOF @@ -281,66 +279,54 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # reshape u to expose tensor product structure vec = fold(grp.space, vec) - diff_mat = get_diff_mat(actx, grp, grp) + ijm = fold(grp.space, ijm) - # weak form case: - # 3D weak_x: einsum("estu,ps,qt,ru->epqr", - # f, stiff_1D, mass_1D, mass_1D) - # TODO:? make this more general, maybe offload to a function that - # generates argnames and einsum specs if metric_in_matvec: - stiff_1D, mass_1D = diff_mat + stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) + + # TODO:? make this more general, maybe offload to a function that + # generates argnames and einsum specs grad = make_obj_array([ - actx.einsum( - f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + - "ij," + - ("ab,cd" if grp.dim == 3 else "ab") + - "->" - f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - vec, - stiff_1D, - *(mass_1D,)*(grp.dim-1), - arg_names=("vec", "stiff_1D", - *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + unfold( + grp.space, + actx.einsum( + f"re{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + "ij," + + ("ab,cd" if grp.dim == 3 else "ab") + + "->" + f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", + ijm[i], + vec, + stiff_1d, + *(mass_1d,)*(grp.dim-1), + arg_names=("inv_jac_t", "vec", "stiff_1d", + *(("mass_1d_1", "mass_1d_2")[:grp.dim-1])), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered()))) for i in range(grp.dim) ]) - # Carries out, e.g., 3D strong form contraction - # x partial: einsum("il,eljk->eijk", D, f) else: + diff_mat = get_diff_mat(actx, grp, grp) + grad = make_obj_array([ - actx.einsum( - "yz," + - f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + - f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", - diff_mat, - vec, - arg_names=("diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) + unfold( + grp.space, + actx.einsum( + f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + + "yz," + + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", + ijm[i], + diff_mat, + vec, + arg_names=("inv_jac_t", "diff_mat", "vec"), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered()))) for i in range(grp.dim) ]) - # {{{ unreshape grad and apply geometric factors - - # TODO: Chain einsums together with geometric factors - grad = actx.np.stack([ - unfold(grp.space, grad[rst_axis]) - for rst_axis in range(grp.dim) - ]) - - grad = actx.einsum( - "xrej,rej->xej", - ijm, - grad, - arg_names=("inv_jac_mat", "grad"), - tagged=(FirstAxisIsElementsTag(),) - ) - - # }}} - return grad # }}} diff --git a/test/test_op.py b/test/test_op.py index 0e569e541..f5b4d116e 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -48,7 +48,7 @@ # {{{ gradient @pytest.mark.parametrize("group_cls", [ - # SimplexElementGroup, + SimplexElementGroup, TensorProductElementGroup ]) @pytest.mark.parametrize("form", ["strong", "weak"]) From 7e389a7cc9a60240bd4d98e7e8756dc82f0c4bfb Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 8 Jan 2024 10:32:45 -0600 Subject: [PATCH 48/66] Start updating div to be a single einsum --- grudge/op.py | 69 +++++++++++++++++++++++++++++++++++++------------ test/test_op.py | 2 +- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index e7e48db8c..a5965f095 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -246,6 +246,8 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): + # TODO: add note about inverse mass simplification, point to + # op.inverse_mass (assuming this is where the explanation will live) """ Exploits tensor product structure to reduce complexity. Applies a differentiation operator containing 1D information to a tensor of DOF @@ -391,53 +393,86 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): return compute_simplicial_div(actx, grp, grp, diff_mat, vec, ijm, metric_in_matvec) - # reshape u to expose tensor product structure - diff_mat = get_diff_mat(actx, grp, grp) vec = make_obj_array([ fold(grp.space, vec[xyz_axis]) for xyz_axis in range(grp.dim) ]) - # weak form if metric_in_matvec: - stiff_1D, mass_1D = diff_mat + stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) + # partials = make_obj_array([ + # make_obj_array([ + # actx.einsum( + # f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + # "ij," + + # ("ab,cd" if grp.dim == 3 else "ab") + + # "->" + # f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", + # vec[func_axis], + # stiff_1d, + # *(mass_1d,)*(grp.dim-1), + # arg_names=("vec", "stiff_1D", + # *(("mass_1d_1", "mass_1d_2")[:grp.dim-1])), + # tagged=(FirstAxisIsElementsTag(), + # OutputIsTensorProductDOFArrayOrdered())) + # for i in range(grp.dim) + # ]) + # for func_axis in range(grp.dim) + # ]) + partials = make_obj_array([ make_obj_array([ actx.einsum( - f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + f"xre{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + + f"xe{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + "ij," + ("ab,cd" if grp.dim == 3 else "ab") + "->" f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - vec[func_axis], - stiff_1D, - *(mass_1D,)*(grp.dim-1), - arg_names=("vec", "stiff_1D", - *(("mass_1D_1", "mass_1D_2")[:grp.dim-1])), + ijm, + vec, + stiff_1d, + *(mass_1d,)*(grp.dim-1), + arg_names=("inv_jac_t", "vec", "stiff_1D", + *(("mass_1d_1", "mass_1d_2")[:grp.dim-1])), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) for i in range(grp.dim) ]) - for func_axis in range(grp.dim) ]) - # strong form else: + diff_mat = get_diff_mat(actx, grp, grp) + # partials = make_obj_array([ + # make_obj_array([ + # actx.einsum( + # "yz," + + # f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + + # f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", + # diff_mat, + # vec[func_axis], + # arg_names=("diff_mat", "vec"), + # tagged=(FirstAxisIsElementsTag(), + # OutputIsTensorProductDOFArrayOrdered())) + # for i in range(grp.dim) + # ]) + # for func_axis in range(grp.dim) + # ]) + partials = make_obj_array([ - make_obj_array([ actx.einsum( "yz," + + f"xre{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", + ijm, diff_mat, - vec[func_axis], - arg_names=("diff_mat", "vec"), + vec, + arg_names=("inv_jac_t", "diff_mat", "vec"), tagged=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered())) for i in range(grp.dim) ]) - for func_axis in range(grp.dim) - ]) # {{{ unreshape, apply geometric factors, and sum over partials diff --git a/test/test_op.py b/test/test_op.py index f5b4d116e..8e8330622 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -202,7 +202,7 @@ def get_flux(u_tpair): # {{{ divergence @pytest.mark.parametrize("group_cls", [ - # SimplexElementGroup, + #SimplexElementGroup, TensorProductElementGroup ]) @pytest.mark.parametrize("form", ["strong", "weak"]) From 9951d655267a44c284cd728337322838f1e615ac Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 16 Jan 2024 10:14:21 -0600 Subject: [PATCH 49/66] start move to single axis operator application --- grudge/op.py | 302 ++++++++++++++++++++++++++------------------------- 1 file changed, 152 insertions(+), 150 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index a5965f095..305cb2dda 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -190,13 +190,62 @@ def _single_axis_derivative_kernel( # {{{ tensor product single axis derivative - # FIXME: actually implement single axis tensor product derivatives def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, xyz_axis, metric_in_matvec): + vec = fold(grp.space, vec) + + if metric_in_matvec: + stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) + + apply_mass_axes = set(range(grp.dim)) - {xyz_axis} + + for ax in apply_mass_axes: + vec_mass_applied = single_axis_operator_application( + actx, grp.dim, mass_1d, ax, vec, + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("mass_1d", "vec") + ) + + ref_weak_derivative = unfold( + grp.space, + single_axis_operator_application( + actx, grp.dim, stiff_1d, xyz_axis, vec_mass_applied, + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("stiff_1d", "vec_with_mass_applied")) + ) + + derivative = actx.einsum( + 'rej,ej->ej', + ijm[xyz_axis], + ref_weak_derivative, + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", "ref_weak_derivative") + ) + + else: + diff_mat = get_diff_mat(actx, grp, grp) + + ref_derivative = unfold( + grp.space, + single_axis_operator_application( + actx, grp.dim, diff_mat, xyz_axis, vec, + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("diff_mat", "vec")) + ) + + derivative = actx.einsum( + 'rej,ej->ej', + ijm[xyz_axis], + ref_derivative, + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", "ref_derivs") + ) - return compute_simplicial_derivative(actx, grp, grp, get_diff_mat, vec, - ijm, xyz_axis, metric_in_matvec) + return derivative # }}} @@ -204,17 +253,17 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, # {{{ simplicial single axis derivative def compute_simplicial_derivative(actx, in_grp, out_grp, - get_diff_mat, vec_i, ijm_i, + get_diff_mat, vec, ijm, xyz_axis, metric_in_matvec): # r for rst axis return actx.einsum( "rej,rij,ej->ei" if metric_in_matvec else "rei,rij,ej->ei", - ijm_i[xyz_axis], + ijm[xyz_axis], get_diff_mat( actx, out_element_group=out_grp, in_element_group=in_grp), - vec_i, + vec, arg_names=("inv_jac_t", "ref_stiffT_mat", "vec", ), tagged=(FirstAxisIsElementsTag(),)) @@ -224,9 +273,8 @@ def compute_simplicial_derivative(actx, in_grp, out_grp, return DOFArray( actx, data=tuple( - compute_tensor_product_derivative(actx, in_grp, out_grp, - get_diff_mat, vec_i, ijm_i, - xyz_axis, metric_in_matvec) + compute_tensor_product_derivative(actx, in_grp, get_diff_mat, vec_i, + ijm_i, xyz_axis, metric_in_matvec) if isinstance(in_grp, TensorProductElementGroupBase) else compute_simplicial_derivative(actx, in_grp, out_grp, get_diff_mat, vec_i, ijm_i, @@ -279,57 +327,69 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, return compute_simplicial_grad(actx, grp, grp, diff_mat, vec, ijm, metric_in_matvec) - # reshape u to expose tensor product structure + # reshape vector to expose tensor product structure vec = fold(grp.space, vec) - ijm = fold(grp.space, ijm) if metric_in_matvec: stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) - # TODO:? make this more general, maybe offload to a function that - # generates argnames and einsum specs - grad = make_obj_array([ - unfold( + grad = [] + for xyz_axis in range(grp.dim): + grad.append(vec) + apply_mass_axes = set(range(grp.dim)) - {xyz_axis} + + # apply mass operators + for ax in apply_mass_axes: + grad[xyz_axis] = single_axis_operator_application( + actx, grp.dim, mass_1d, ax, grad[xyz_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("mass_1d", "vec") + ) + + # apply stiffness operator and unfold + grad[xyz_axis] = unfold( grp.space, - actx.einsum( - f"re{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + - f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + - "ij," + - ("ab,cd" if grp.dim == 3 else "ab") + - "->" - f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - ijm[i], - vec, - stiff_1d, - *(mass_1d,)*(grp.dim-1), - arg_names=("inv_jac_t", "vec", "stiff_1d", - *(("mass_1d_1", "mass_1d_2")[:grp.dim-1])), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered()))) - for i in range(grp.dim) - ]) + single_axis_operator_application( + actx, grp.dim, stiff_1d, xyz_axis, grad[xyz_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("stiff_1d", "vec")) + ) + # apply metric terms + grad[xyz_axis] = actx.einsum( + 'rej,ej->ej', + ijm[xyz_axis], + grad[xyz_axis], + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", "vec") + ) else: diff_mat = get_diff_mat(actx, grp, grp) - grad = make_obj_array([ - unfold( + grad = [] + for xyz_axis in range(grp.dim): + grad.append(vec) + grad[xyz_axis] = unfold( grp.space, - actx.einsum( - f"re{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + - "yz," + - f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + - f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", - ijm[i], - diff_mat, - vec, - arg_names=("inv_jac_t", "diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered()))) - for i in range(grp.dim) - ]) - - return grad + single_axis_operator_application( + actx, grp.dim, diff_mat, xyz_axis, grad[xyz_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("diff_mat", "vec") + ) + ) + + grad[xyz_axis] = actx.einsum( + "rej,ej->ej", + ijm[xyz_axis], + grad[xyz_axis], + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", "vec") + ) + + return make_obj_array(grad) # }}} @@ -400,106 +460,22 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): if metric_in_matvec: stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) - # partials = make_obj_array([ - # make_obj_array([ - # actx.einsum( - # f"e{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + - # "ij," + - # ("ab,cd" if grp.dim == 3 else "ab") + - # "->" - # f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - # vec[func_axis], - # stiff_1d, - # *(mass_1d,)*(grp.dim-1), - # arg_names=("vec", "stiff_1D", - # *(("mass_1d_1", "mass_1d_2")[:grp.dim-1])), - # tagged=(FirstAxisIsElementsTag(), - # OutputIsTensorProductDOFArrayOrdered())) - # for i in range(grp.dim) - # ]) - # for func_axis in range(grp.dim) - # ]) - - partials = make_obj_array([ - make_obj_array([ - actx.einsum( - f"xre{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + - f"xe{'bd'[:i]}j{'bd'[i:grp.dim-1]}," + - "ij," + - ("ab,cd" if grp.dim == 3 else "ab") + - "->" - f"e{'ac'[:i]}i{'ac'[i:grp.dim-1]}", - ijm, - vec, - stiff_1d, - *(mass_1d,)*(grp.dim-1), - arg_names=("inv_jac_t", "vec", "stiff_1D", - *(("mass_1d_1", "mass_1d_2")[:grp.dim-1])), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(grp.dim) - ]) - ]) + + partials = [] + for func_axis in range(grp.dim): + + partials.append(vec[func_axis]) + for xyz_axis in range(grp.dim): + + + div = 0 else: diff_mat = get_diff_mat(actx, grp, grp) - # partials = make_obj_array([ - # make_obj_array([ - # actx.einsum( - # "yz," + - # f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + - # f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", - # diff_mat, - # vec[func_axis], - # arg_names=("diff_mat", "vec"), - # tagged=(FirstAxisIsElementsTag(), - # OutputIsTensorProductDOFArrayOrdered())) - # for i in range(grp.dim) - # ]) - # for func_axis in range(grp.dim) - # ]) - - partials = make_obj_array([ - actx.einsum( - "yz," + - f"xre{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}," + - f"e{'abcdfghijkl'[:i]}z{'mnopqstuvwx'[:grp.dim-i-1]}->" + - f"e{'abcdfghijkl'[:i]}y{'mnopqstuvwx'[:grp.dim-i-1]}", - ijm, - diff_mat, - vec, - arg_names=("inv_jac_t", "diff_mat", "vec"), - tagged=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered())) - for i in range(grp.dim) - ]) - - # {{{ unreshape, apply geometric factors, and sum over partials - - # TODO: Chain einsums together with geometric factors - partials = actx.np.stack([ - unfold(grp.space, partials[xyz_axis][rst_axis]) - for xyz_axis in range(grp.dim) - for rst_axis in range(grp.dim) - ]) - - try: - partials = partials.reshape( - grp.dim, grp.dim, partials.shape[1], partials.shape[2]) - except IndexError: - partials = partials.reshape( - grp.dim, grp.dim, partials.shape[1] - ) - div = actx.einsum( - "xrej,xrej->ej", - ijm, - partials, - arg_names=("inv_jac_mat", "partials",), - tagged=(FirstAxisIsElementsTag(),) - ) + partials = [] - # }}} + div = 0 return div @@ -767,22 +743,22 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): vdm = mp.vandermonde(basis_1d.functions, nodes_1d) vdm_p = mp.vandermonde(basis_1d.gradients, nodes_1d)[0] - mass_1D = la.inv(vdm @ vdm.T) + mass_1d = la.inv(vdm @ vdm.T) diff_mat = la.solve(vdm.T, vdm_p.T).T - stiff_1D = actx.freeze( + stiff_1d = actx.freeze( actx.tag_axis(1, DiscretizationDOFAxisTag(), actx.from_numpy( np.asarray( - diff_mat.T @ mass_1D.T)))) + diff_mat.T @ mass_1d.T)))) - mass_1D = actx.freeze( + mass_1d = actx.freeze( actx.tag_axis(1, DiscretizationDOFAxisTag(), actx.from_numpy( np.asarray( - mass_1D)))) + mass_1d)))) - return (stiff_1D, mass_1D) + return (stiff_1d, mass_1d) # }}} @@ -1412,4 +1388,30 @@ def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: # }}} +# {{{ general single axis operator application + +def single_axis_operator_application(actx, dim, operator, axis, data, + arg_names=None, tags=None): + """ + Used for applying 1D operators to a single axis of a tensor of DOF data. + """ + + if not isinstance(arg_names, tuple): + raise TypeError('arg_names must be a tuple.') + if not isinstance(tags, tuple): + raise TypeError('arg_names must be a tuple.') + + operator_spec = 'ij' + data_spec = f'e{"abcdefghklm"[:axis]}j{"nopqrstuvwxyz"[:dim-axis-1]}' + out_spec = f'e{"abcdefghklm"[:axis]}i{"nopqrstuvwxyz"[:dim-axis-1]}' + + spec = operator_spec + ',' + data_spec + '->' + out_spec + + return actx.einsum(spec, operator, data, + arg_names=arg_names, + tagged=tags) + +# }}} + + # vim: foldmethod=marker From f722b2750b9cb911a078cd255c1fab315900835f Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Mon, 22 Jan 2024 09:03:17 -0600 Subject: [PATCH 50/66] Update div --- grudge/op.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 305cb2dda..b4b22bcfa 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -463,19 +463,61 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): partials = [] for func_axis in range(grp.dim): - partials.append(vec[func_axis]) - for xyz_axis in range(grp.dim): + apply_mass_axes = set(range(grp.dim)) - {func_axis} + for ax in apply_mass_axes: + partials[func_axis] = single_axis_operator_application( + actx, grp.dim, mass_1d, ax, partials[func_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("mass_1d", f"vec_{func_axis}")) - div = 0 + partials[func_axis] = unfold( + grp.space, + single_axis_operator_application( + actx, grp.dim, stiff_1d, func_axis, partials[func_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("stiff_1d", f"vec_{func_axis}")) + ) + + partials[func_axis] = actx.einsum( + "rej,ej->ej", + ijm[func_axis], + partials[func_axis], + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", f"partials_{func_axis}") + ) else: diff_mat = get_diff_mat(actx, grp, grp) partials = [] + for func_axis in range(grp.dim): + partials.append(vec[func_axis]) + + partials[func_axis] = unfold( + grp.space, + single_axis_operator_application( + actx, grp.dim, diff_mat, func_axis, partials[func_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("diff_mat", f"partials_{func_axis}") + ) + ) + + partials[func_axis] = actx.einsum( + "rej,ej->ej", + ijm[func_axis], + partials[func_axis], + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", f"partials_{func_axis}") + ) - div = 0 + div = 0 + for i in range(grp.dim): + div += partials[i] return div From 4b1e0b64d83de516f1689eba587124dde52bbfb3 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 27 Jan 2024 15:17:21 -0600 Subject: [PATCH 51/66] temp changes --- grudge/op.py | 59 +++++++++++++--------------------------------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index b4b22bcfa..030bbd028 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -462,58 +462,27 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) partials = [] - for func_axis in range(grp.dim): - partials.append(vec[func_axis]) - - apply_mass_axes = set(range(grp.dim)) - {func_axis} - for ax in apply_mass_axes: - partials[func_axis] = single_axis_operator_application( - actx, grp.dim, mass_1d, ax, partials[func_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), - arg_names=("mass_1d", f"vec_{func_axis}")) - - partials[func_axis] = unfold( - grp.space, - single_axis_operator_application( - actx, grp.dim, stiff_1d, func_axis, partials[func_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), - arg_names=("stiff_1d", f"vec_{func_axis}")) - ) - - partials[func_axis] = actx.einsum( - "rej,ej->ej", - ijm[func_axis], - partials[func_axis], - tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", f"partials_{func_axis}") - ) else: diff_mat = get_diff_mat(actx, grp, grp) partials = [] - for func_axis in range(grp.dim): - partials.append(vec[func_axis]) - - partials[func_axis] = unfold( - grp.space, - single_axis_operator_application( - actx, grp.dim, diff_mat, func_axis, partials[func_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), - arg_names=("diff_mat", f"partials_{func_axis}") + for func_axis in range(vec.shape[0]): + partials.append([]) + for xyz_axis in range(grp.dim): + partials[func_axis] = unfold( + grp.space, + single_axis_operator_application( + actx, grp.dim, diff_mat, xyz_axis, vec, + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("diff_mat", f"vec_{xyz_axis}_{func_axis}") + ) ) - ) - partials[func_axis] = actx.einsum( - "rej,ej->ej", - ijm[func_axis], - partials[func_axis], - tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", f"partials_{func_axis}") - ) + + + div = 0 for i in range(grp.dim): From 47d667425a69c1991fc1b23a04e057ee975f8f93 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sat, 27 Jan 2024 16:05:08 -0600 Subject: [PATCH 52/66] Working div recast as series of single axis operator applications --- grudge/op.py | 63 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 030bbd028..3f8e20617 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -454,42 +454,71 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): metric_in_matvec) vec = make_obj_array([ - fold(grp.space, vec[xyz_axis]) - for xyz_axis in range(grp.dim) + fold(grp.space, vec[func_axis]) + for func_axis in range(vec.shape[0]) ]) if metric_in_matvec: stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) partials = [] - - else: - diff_mat = get_diff_mat(actx, grp, grp) - - partials = [] for func_axis in range(vec.shape[0]): - partials.append([]) + ref = [] for xyz_axis in range(grp.dim): - partials[func_axis] = unfold( - grp.space, - single_axis_operator_application( - actx, grp.dim, diff_mat, xyz_axis, vec, + ref.append(vec[func_axis]) + + apply_mass_axes = set(range(grp.dim)) - {xyz_axis} + for ax in apply_mass_axes: + ref[xyz_axis] = single_axis_operator_application( + actx, grp.dim, mass_1d, ax, ref[xyz_axis], tags=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered(),), - arg_names=("diff_mat", f"vec_{xyz_axis}_{func_axis}") + arg_names=("mass_1d", f"vec_{func_axis}_{xyz_axis}") ) + + ref[xyz_axis] = single_axis_operator_application( + actx, grp.dim, stiff_1d, xyz_axis, ref[xyz_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("stiff_1d", f"vec_{func_axis}_{xyz_axis}") ) + partials.append(ref) + else: + diff_mat = get_diff_mat(actx, grp, grp) + partials = [] + for func_axis in range(vec.shape[0]): + ref = [] + for xyz_axis in range(grp.dim): + ref.append(vec[func_axis]) + ref[xyz_axis] = single_axis_operator_application( + actx, grp.dim, diff_mat, xyz_axis, ref[xyz_axis], + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("diff_mat", f"vec_{func_axis}_{xyz_axis}") + ) - div = 0 - for i in range(grp.dim): - div += partials[i] + partials.append(ref) - return div + partials = actx.np.stack([ + unfold(grp.space, partials[func_axis][xyz_axis]) + for func_axis in range(grp.dim) + for xyz_axis in range(grp.dim) + ]) + partials = partials.reshape(grp.dim, grp.dim, *partials.shape[-2:]) + + div = actx.einsum( + 'xrej,xrej->ej', + ijm, + partials, + arg_names=("inv_jac_t", "partials"), + tagged=(FirstAxisIsElementsTag(),) + ) + return div # }}} From 201c018060cd9e734c5855476d7e0100afcf3724 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 28 Jan 2024 11:04:18 -0600 Subject: [PATCH 53/66] Update _apply_inverse_mass_operator for TP elements --- grudge/array_context.py | 11 +++ grudge/op.py | 149 +++++++++++++++++++++++++--------------- 2 files changed, 105 insertions(+), 55 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index ab088c457..ac438bf17 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -647,6 +647,17 @@ class OutputIsTensorProductDOFArrayOrdered(Tag): """ pass + +class MassMatrix1d(Tag): + """Used in DAG transformation to realize algebraic simplification of 1D + inverse mass operator times mass operator. + """ + pass + +class InverseMassMatrix1d(Tag): + """See MassMatrix1d. + """ + # }}} # {{{ Eager TP array context diff --git a/grudge/op.py b/grudge/op.py index 3f8e20617..34af31c98 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -75,11 +75,13 @@ from functools import partial from meshmode.dof_array import DOFArray, warn +from meshmode.discretization.poly_element import ( + TensorProductElementGroupBase as TensorProductElementGroup, + SimplexElementGroupBase as SimplexElementGroup) from meshmode.transform_metadata import (FirstAxisIsElementsTag, DiscretizationDOFAxisTag, DiscretizationElementAxisTag, DiscretizationFaceAxisTag) -from meshmode.discretization.poly_element import TensorProductElementGroupBase from modepy.tools import ( reshape_array_for_tensor_product_space as fold, @@ -87,8 +89,7 @@ from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc -from grudge.array_context import ( - OutputIsTensorProductDOFArrayOrdered) +from grudge.array_context import OutputIsTensorProductDOFArrayOrdered from pytools import keyed_memoize_in from pytools.obj_array import make_obj_array @@ -275,7 +276,7 @@ def compute_simplicial_derivative(actx, in_grp, out_grp, data=tuple( compute_tensor_product_derivative(actx, in_grp, get_diff_mat, vec_i, ijm_i, xyz_axis, metric_in_matvec) - if isinstance(in_grp, TensorProductElementGroupBase) + if isinstance(in_grp, TensorProductElementGroup) else compute_simplicial_derivative(actx, in_grp, out_grp, get_diff_mat, vec_i, ijm_i, xyz_axis, metric_in_matvec) @@ -297,28 +298,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, # TODO: add note about inverse mass simplification, point to # op.inverse_mass (assuming this is where the explanation will live) """ - Exploits tensor product structure to reduce complexity. Applies a - differentiation operator containing 1D information to a tensor of DOF - data. For example, in the 2D strong form case, this computes partial - derivatives in a similar manner to - - .. math:: - - \partial_x \mathbf{f}_{ij} = \sum_{\ell} \mathbf{J}^e_{ij} - \mathbf{D}_{i\ell} \mathbf{f}_{\ell j} - - where $\mathbf{D}$ is a 1D differentiation operator, $\mathbf{f}$ is a - vector of function data, $\mathbf{J}^e$ is the element Jacobian matrix. - The weak form uses a 1D element mass operator and a 1D element stiffness - operator to perform the contraction - - .. math:: - - \partial_x \mathbf{f}_{ij} = \sum_{\ell,b} \mathbf{J}^e_{\ell b} - \mathbf{f}_{\ell b} \mathbf{S}^e_{i\ell} \mathbf{M}^e_{jb} """ - if grp.dim > 3 and metric_in_matvec: warn('Efficient tensor product weak ' 'differentiation operators only ' @@ -344,7 +325,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, actx, grp.dim, mass_1d, ax, grad[xyz_axis], tags=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered(),), - arg_names=("mass_1d", "vec") + arg_names=("mass_1d", f"vec_{xyz_axis}") ) # apply stiffness operator and unfold @@ -354,7 +335,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, actx, grp.dim, stiff_1d, xyz_axis, grad[xyz_axis], tags=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered(),), - arg_names=("stiff_1d", "vec")) + arg_names=("stiff_1d", f"vec_{xyz_axis}")) ) # apply metric terms @@ -363,7 +344,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, ijm[xyz_axis], grad[xyz_axis], tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", "vec") + arg_names=("inv_jac_t", f"vec_{xyz_axis}") ) else: diff_mat = get_diff_mat(actx, grp, grp) @@ -377,7 +358,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, actx, grp.dim, diff_mat, xyz_axis, grad[xyz_axis], tags=(FirstAxisIsElementsTag(), OutputIsTensorProductDOFArrayOrdered(),), - arg_names=("diff_mat", "vec") + arg_names=("diff_mat", f"vec_{xyz_axis}") ) ) @@ -386,7 +367,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, ijm[xyz_axis], grad[xyz_axis], tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", "vec") + arg_names=("inv_jac_t", f"vec_{xyz_axis}") ) return make_obj_array(grad) @@ -416,7 +397,7 @@ def compute_simplicial_grad(actx, in_grp, out_grp, get_diff_mat, vec_i, per_group_grads = [ compute_tensor_product_grad(actx, in_grp, get_diff_mat, vec_i, ijm_i, metric_in_matvec) - if isinstance(in_grp, TensorProductElementGroupBase) + if isinstance(in_grp, TensorProductElementGroup) else compute_simplicial_grad(actx, in_grp, out_grp, get_diff_mat, vec_i, ijm_i, metric_in_matvec) @@ -544,7 +525,7 @@ def compute_simplicial_div(actx, in_grp, out_grp, get_diff_mat, vec_i, per_group_divs = [ compute_tensor_product_div(actx, in_grp, get_diff_mat, vec_i, ijm_i) - if isinstance(in_grp, TensorProductElementGroupBase) + if isinstance(in_grp, TensorProductElementGroup) # r for rst axis # x for xyz axis @@ -573,7 +554,7 @@ def _reference_derivative_matrices(actx: ArrayContext, actx, _reference_derivative_matrices, lambda grp: grp.discretization_key()) def get_ref_derivative_mats(grp): - if isinstance(grp, TensorProductElementGroupBase): + if isinstance(grp, TensorProductElementGroup): import modepy as mp import numpy.linalg as la @@ -593,13 +574,18 @@ def get_ref_derivative_mats(grp): 1: DiscretizationDOFAxisTag()}, diff_mat))) - else: + elif isinstance(grp, SimplexElementGroup): from meshmode.discretization.poly_element import diff_matrices return actx.freeze( actx.tag_axis( 1, DiscretizationDOFAxisTag(), actx.from_numpy( np.asarray(diff_matrices(grp))))) + + else: + raise TypeError("grp must be either a TensorProductElementGroup or" + f" a SimplexElementGroup. Found {grp}") + return get_ref_derivative_mats(out_element_group) @@ -772,7 +758,7 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): # {{{ tensor product case - if isinstance(out_grp, TensorProductElementGroupBase): + if isinstance(out_grp, TensorProductElementGroup): import modepy as mp import numpy.linalg as la @@ -792,11 +778,13 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): np.asarray( diff_mat.T @ mass_1d.T)))) + from grudge.array_context import MassMatrix1d mass_1d = actx.freeze( - actx.tag_axis(1, DiscretizationDOFAxisTag(), - actx.from_numpy( - np.asarray( - mass_1d)))) + actx.tag_axis( + 1, (DiscretizationDOFAxisTag(),), + actx.from_numpy(np.asarray(mass_1d))) + ) + mass_1d = actx.tag(MassMatrix1d(), mass_1d) return (stiff_1d, mass_1d) @@ -831,6 +819,7 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): ).copy() # contigify the array ) ) + return get_ref_stiffness_transpose_mat(out_element_group, in_element_group) @@ -1141,14 +1130,31 @@ def reference_inverse_mass_matrix(actx: ArrayContext, element_group): lambda grp: grp.discretization_key()) def get_ref_inv_mass_mat(grp): from modepy import inverse_mass_matrix - basis = grp.basis_obj() - return actx.freeze( - actx.tag_axis(0, DiscretizationDOFAxisTag(), - actx.from_numpy( - np.asarray( - inverse_mass_matrix(basis.functions, grp.unit_nodes), - order="C")))) + if isinstance(grp, TensorProductElementGroup): + basis_1d = grp.bases_1d() + nodes_1d = grp.unit_nodes_1d + inv_mass_1d = inverse_mass_matrix(basis_1d.functions, nodes_1d) + + from grudge.array_context import InverseMassMatrix1d + inv_mass_1d = actx.tag_axis(0, DiscretizationDOFAxisTag(), + actx.from_numpy(np.asarray(inv_mass_1d))) + inv_mass_1d = actx.freeze( + actx.tag(InverseMassMatrix1d(), inv_mass_1d)) + + return inv_mass_1d + elif isinstance(grp, SimplexElementGroup): + basis = grp.basis_obj() + + return actx.freeze( + actx.tag_axis(0, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray( + inverse_mass_matrix(basis.functions, grp.unit_nodes), + order="C")))) + else: + raise TypeError("grp must be either a TensorProductElementGroup or" + f" a SimplexElementGroup. Found {grp}") return get_ref_inv_mass_mat(element_group) @@ -1173,15 +1179,48 @@ def _apply_inverse_mass_operator( discr = dcoll.discr_from_dd(dd_in) inv_area_elements = 1./area_element(actx, dcoll, dd=dd_in, _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) + + def apply_to_tensor_product_elements(grp, jac_inv, vec, ref_inv_mass): + + vec = fold(grp.space, vec) + + for xyz_axis in range(grp.dim): + vec = single_axis_operator_application( + actx, grp.dim, ref_inv_mass, xyz_axis, vec, + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), + arg_names=("ref_inv_mass_1d", "vec")) + + vec = unfold(grp.space, vec) + + return actx.einsum( + "ei,ei->ei", + jac_inv, + vec, + tagged=(FirstAxisIsElementsTag(),) + ) + + + def apply_to_simplicial_elements(jac_inv, vec, ref_inv_mass): + + # Based on https://arxiv.org/pdf/1608.03836.pdf + # true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv + return actx.einsum( + "ei,ij,ej->ei", + jac_inv, + ref_inv_mass, + vec, + tagged=(FirstAxisIsElementsTag(),)) + group_data = [ - # Based on https://arxiv.org/pdf/1608.03836.pdf - # true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv - actx.einsum("ei,ij,ej->ei", - jac_inv, - reference_inverse_mass_matrix(actx, element_group=grp), - vec_i, - tagged=(FirstAxisIsElementsTag(),)) - for grp, jac_inv, vec_i in zip(discr.groups, inv_area_elements, vec)] + apply_to_tensor_product_elements( + grp, jac_inv, vec_i, + reference_inverse_mass_matrix(actx, element_group=grp)) + if isinstance(grp, TensorProductElementGroup) else + apply_to_simplicial_elements(jac_inv, vec_i, + reference_inverse_mass_matrix(actx, element_group=grp)) + for grp, jac_inv, vec_i in zip(discr.groups, inv_area_elements, vec) + ] return DOFArray(actx, data=tuple(group_data)) @@ -1437,9 +1476,9 @@ def single_axis_operator_application(actx, dim, operator, axis, data, """ if not isinstance(arg_names, tuple): - raise TypeError('arg_names must be a tuple.') + raise TypeError("arg_names must be a tuple.") if not isinstance(tags, tuple): - raise TypeError('arg_names must be a tuple.') + raise TypeError("arg_names must be a tuple.") operator_spec = 'ij' data_spec = f'e{"abcdefghklm"[:axis]}j{"nopqrstuvwxyz"[:dim-axis-1]}' From b2e909ee13e51aa39931acedc4af1c01979afdb2 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 28 Jan 2024 11:21:41 -0600 Subject: [PATCH 54/66] Minor formatting --- grudge/op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/grudge/op.py b/grudge/op.py index 34af31c98..b4d77ab67 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -1180,6 +1180,7 @@ def _apply_inverse_mass_operator( inv_area_elements = 1./area_element(actx, dcoll, dd=dd_in, _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) + def apply_to_tensor_product_elements(grp, jac_inv, vec, ref_inv_mass): vec = fold(grp.space, vec) @@ -1212,6 +1213,7 @@ def apply_to_simplicial_elements(jac_inv, vec, ref_inv_mass): vec, tagged=(FirstAxisIsElementsTag(),)) + group_data = [ apply_to_tensor_product_elements( grp, jac_inv, vec_i, From fdcf4b88191f190779fe6fee182e55a9a4796a49 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 11 Feb 2024 12:58:31 -0600 Subject: [PATCH 55/66] Add simplicial test back to div convergence test --- test/test_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_op.py b/test/test_op.py index 8e8330622..48d93a5a1 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -202,17 +202,17 @@ def get_flux(u_tpair): # {{{ divergence @pytest.mark.parametrize("group_cls", [ - #SimplexElementGroup, + SimplexElementGroup, TensorProductElementGroup ]) @pytest.mark.parametrize("form", ["strong", "weak"]) -@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize("dim", [1, 2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ (False, False), (True, False), (True, True) -]) + ]) def test_divergence(actx_factory, form, dim, order, vectorize, nested, group_cls, visualize=False): actx = actx_factory() From d9dbffe1e6cab119fbf2be90a0bea82cdae34d62 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 29 Feb 2024 14:31:21 -0600 Subject: [PATCH 56/66] Start working on TP DOF axis tagging (nothing working yet) --- test/test_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_op.py b/test/test_op.py index 48d93a5a1..bb6930889 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -48,11 +48,11 @@ # {{{ gradient @pytest.mark.parametrize("group_cls", [ - SimplexElementGroup, + #SimplexElementGroup, TensorProductElementGroup ]) -@pytest.mark.parametrize("form", ["strong", "weak"]) -@pytest.mark.parametrize("dim", [1, 2, 3]) +@pytest.mark.parametrize("form", ["strong"]) +@pytest.mark.parametrize("dim", [2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ (False, False), From d6110236af4df1f71353764f0048b99f2ff67c3e Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Thu, 29 Feb 2024 14:32:08 -0600 Subject: [PATCH 57/66] another TP DOF axis tagging commit (still nothing working yet) --- grudge/array_context.py | 19 +++++++++++++++++- grudge/op.py | 43 +++++++++++++++++++++++++++-------------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index ac438bf17..19ff4f4c5 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -105,6 +105,10 @@ import pyopencl.tools from mpi4py import MPI +# }}} + + +# {{{ pyopencl class PyOpenCLArrayContext(_PyOpenCLArrayContextBase): """Inherits from :class:`meshmode.array_context.PyOpenCLArrayContext`. Extends it @@ -130,12 +134,13 @@ def transform_loopy_program(self, t_unit): if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): new_args = [] + for arg in knl.args: if arg.is_output: arg = arg.copy(dim_tags=( f"N{len(arg.shape)-1}," + ",".join(f"N{i}" - for i in range(len(arg.shape)-1)) + for i in range(len(arg.shape)-1)) )) new_args.append(arg) @@ -634,6 +639,7 @@ def get_reasonable_array_context_class( # }}} + # {{{ tensor product-specific machinery class OutputIsTensorProductDOFArrayOrdered(Tag): @@ -647,6 +653,14 @@ class OutputIsTensorProductDOFArrayOrdered(Tag): """ pass +class TensorProductDOFAxis(Tag): + """ + Tag an axis as being an axis containing the DOFs of a tensor-product + discretization. Used to signify that the strides associated with the array + containing this axis will be neither column nor row major. + """ + pass + class MassMatrix1d(Tag): """Used in DAG transformation to realize algebraic simplification of 1D @@ -660,18 +674,21 @@ class InverseMassMatrix1d(Tag): # }}} + # {{{ Eager TP array context class TensorProductArrayContext(_PyOpenCLArrayContextBase): """Specialized array context for use with tensor product elements. """ # }}} + # {{{ Lazy tensor product array context class PytatoTensorProductArrayContext(PytatoPyOpenCLArrayContext): def transform_dag(self, dag): return super().transform_dag(dag) # }}} + # }}} # vim: foldmethod=marker diff --git a/grudge/op.py b/grudge/op.py index b4d77ab67..e9980fa70 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -89,7 +89,10 @@ from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc -from grudge.array_context import OutputIsTensorProductDOFArrayOrdered +from grudge.array_context import ( + OutputIsTensorProductDOFArrayOrdered, + TensorProductDOFAxis +) from pytools import keyed_memoize_in from pytools.obj_array import make_obj_array @@ -194,7 +197,12 @@ def _single_axis_derivative_kernel( def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, xyz_axis, metric_in_matvec): - vec = fold(grp.space, vec) + vec = tag_axes( + actx, + { i: TensorProductDOFAxis() for i in range(1, grp.dim) }, + fold(grp.space, vec) + ) + if metric_in_matvec: stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) @@ -204,8 +212,7 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, for ax in apply_mass_axes: vec_mass_applied = single_axis_operator_application( actx, grp.dim, mass_1d, ax, vec, - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tags=(FirstAxisIsElementsTag(),), arg_names=("mass_1d", "vec") ) @@ -213,8 +220,7 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, stiff_1d, xyz_axis, vec_mass_applied, - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tags=(FirstAxisIsElementsTag(),), arg_names=("stiff_1d", "vec_with_mass_applied")) ) @@ -233,8 +239,7 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, diff_mat, xyz_axis, vec, - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tags=(FirstAxisIsElementsTag(),), arg_names=("diff_mat", "vec")) ) @@ -309,7 +314,11 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec) # reshape vector to expose tensor product structure - vec = fold(grp.space, vec) + vec = tag_axes( + actx, + { i: TensorProductDOFAxis() for i in range(1, grp.dim+1) }, + fold(grp.space, vec) + ) if metric_in_matvec: stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) @@ -323,8 +332,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, for ax in apply_mass_axes: grad[xyz_axis] = single_axis_operator_application( actx, grp.dim, mass_1d, ax, grad[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tags=(FirstAxisIsElementsTag(),), arg_names=("mass_1d", f"vec_{xyz_axis}") ) @@ -333,8 +341,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, stiff_1d, xyz_axis, grad[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tags=(FirstAxisIsElementsTag(),), arg_names=("stiff_1d", f"vec_{xyz_axis}")) ) @@ -1480,7 +1487,7 @@ def single_axis_operator_application(actx, dim, operator, axis, data, if not isinstance(arg_names, tuple): raise TypeError("arg_names must be a tuple.") if not isinstance(tags, tuple): - raise TypeError("arg_names must be a tuple.") + raise TypeError("tags must be a tuple.") operator_spec = 'ij' data_spec = f'e{"abcdefghklm"[:axis]}j{"nopqrstuvwxyz"[:dim-axis-1]}' @@ -1488,9 +1495,15 @@ def single_axis_operator_application(actx, dim, operator, axis, data, spec = operator_spec + ',' + data_spec + '->' + out_spec - return actx.einsum(spec, operator, data, + result = tag_axes( + actx, + { i: TensorProductDOFAxis() for i in range(1, dim+1) }, + actx.einsum(spec, operator, data, arg_names=arg_names, tagged=tags) + ) + + return result # }}} From f9a53301e37dcf56754eb61f995d663d4235ed68 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 1 Mar 2024 11:20:15 -0600 Subject: [PATCH 58/66] Start implementing call_loopy in grudge eager actx --- grudge/array_context.py | 31 ++++++++++++++++++++++++++++++- grudge/op.py | 41 +++++++++++++++++------------------------ test/test_op.py | 8 ++++---- 3 files changed, 51 insertions(+), 29 deletions(-) diff --git a/grudge/array_context.py b/grudge/array_context.py index 19ff4f4c5..ec4aeae9f 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -134,7 +134,6 @@ def transform_loopy_program(self, t_unit): if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): new_args = [] - for arg in knl.args: if arg.is_output: arg = arg.copy(dim_tags=( @@ -152,6 +151,36 @@ def transform_loopy_program(self, t_unit): return super().transform_loopy_program(t_unit) + + def call_loopy(self, t_unit, **kwargs): + # NOTE: modifying strides pertaining to tensor product axes is done here + # since that information is not available in the arguments passed to + # `transform_loopy_program` in eager evaluation + + default_ep = t_unit.default_entrypoint + + # {{{ process kwargs with TP axis tags + + if default_ep.tags_of_type(OutputIsTensorProductDOFArrayOrdered): + new_args = [] + + for arg_name in kwargs.keys(): + kwarg = kwargs[arg_name].axes + + for axis in kwarg.axes: + if axis.tags_of_type(TensorProductDOFAxis): + if arg_name in default_ep.arg_dict.keys(): + arg = default_ep.arg_dict[arg_name] + + dim_tags = ( + + ) + + + # }}} + + return super().call_loopy(t_unit, **kwargs) + # }}} diff --git a/grudge/op.py b/grudge/op.py index e9980fa70..cd9b3d08e 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -212,7 +212,8 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, for ax in apply_mass_axes: vec_mass_applied = single_axis_operator_application( actx, grp.dim, mass_1d, ax, vec, - tags=(FirstAxisIsElementsTag(),), + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("mass_1d", "vec") ) @@ -220,7 +221,8 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, stiff_1d, xyz_axis, vec_mass_applied, - tags=(FirstAxisIsElementsTag(),), + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("stiff_1d", "vec_with_mass_applied")) ) @@ -239,7 +241,8 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, diff_mat, xyz_axis, vec, - tags=(FirstAxisIsElementsTag(),), + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("diff_mat", "vec")) ) @@ -300,19 +303,11 @@ def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): - # TODO: add note about inverse mass simplification, point to - # op.inverse_mass (assuming this is where the explanation will live) """ + Applies 1D operators one-axis-at-a-time to tensor-product discretized + DOF data. """ - if grp.dim > 3 and metric_in_matvec: - warn('Efficient tensor product weak ' - 'differentiation operators only ' - 'implemented for dimension 2 and 3. ' - 'Defaulting to inefficient version.') - return compute_simplicial_grad(actx, grp, grp, diff_mat, vec, ijm, - metric_in_matvec) - # reshape vector to expose tensor product structure vec = tag_axes( actx, @@ -332,7 +327,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, for ax in apply_mass_axes: grad[xyz_axis] = single_axis_operator_application( actx, grp.dim, mass_1d, ax, grad[xyz_axis], - tags=(FirstAxisIsElementsTag(),), + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("mass_1d", f"vec_{xyz_axis}") ) @@ -341,7 +337,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, stiff_1d, xyz_axis, grad[xyz_axis], - tags=(FirstAxisIsElementsTag(),), + tags=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("stiff_1d", f"vec_{xyz_axis}")) ) @@ -433,16 +430,12 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): `_gradient_kernel.compute_tensor_product_grad` for more details. """ - if grp.dim > 3 and metric_in_matvec: - warn('Efficient tensor product weak ' - 'differentiation operators only ' - 'implemented for dimension 2 and 3. ' - 'Defaulting to inefficient version.') - return compute_simplicial_div(actx, grp, grp, diff_mat, vec, ijm, - metric_in_matvec) - vec = make_obj_array([ - fold(grp.space, vec[func_axis]) + tag_axes( + actx, + { i: TensorProductDOFAxis() for i in range(1,grp.dim+1) }, + fold(grp.space, vec[func_axis]) + ) for func_axis in range(vec.shape[0]) ]) diff --git a/test/test_op.py b/test/test_op.py index bb6930889..c12b674f0 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -48,11 +48,11 @@ # {{{ gradient @pytest.mark.parametrize("group_cls", [ - #SimplexElementGroup, + SimplexElementGroup, TensorProductElementGroup ]) -@pytest.mark.parametrize("form", ["strong"]) -@pytest.mark.parametrize("dim", [2, 3]) +@pytest.mark.parametrize("form", ["strong", "weak"]) +@pytest.mark.parametrize("dim", [1, 2, 3]) @pytest.mark.parametrize("order", [2, 3]) @pytest.mark.parametrize(("vectorize", "nested"), [ (False, False), @@ -75,7 +75,7 @@ def test_gradient(actx_factory, form, dim, order, vectorize, nested, if group_cls is TensorProductElementGroup: # no reason to test 1D tensor product elements if dim == 1: - return + pytest.skip() import grudge.dof_desc as dd from meshmode.discretization.poly_element import \ From f21635528c8095d7bc4137228ebd0a1d29f30b7a Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Fri, 1 Mar 2024 14:06:16 -0600 Subject: [PATCH 59/66] Revert eager changes. Remove unnecessary code in lazy actx. Update example --- .../tensor-product-examples/acoustic_pulse.py | 6 +- grudge/array_context.py | 76 ++----------------- grudge/op.py | 49 ++++++------ 3 files changed, 34 insertions(+), 97 deletions(-) diff --git a/examples/tensor-product-examples/acoustic_pulse.py b/examples/tensor-product-examples/acoustic_pulse.py index 8909a5109..9520945f3 100644 --- a/examples/tensor-product-examples/acoustic_pulse.py +++ b/examples/tensor-product-examples/acoustic_pulse.py @@ -113,7 +113,8 @@ def run_acoustic_pulse(actx, final_time=1, resolution=16, overintegration=False, - visualize=False): + visualize=False, + rotate_mesh=False): # eos-related parameters gamma = 1.4 @@ -225,8 +226,7 @@ def main(ctx_factory, order=3, final_time=1, resolution=16, queue = cl.CommandQueue(cl_ctx) if lazy: - from grudge.array_context import PytatoTensorProductArrayContext - actx = PytatoTensorProductArrayContext( + actx = PytatoPyOpenCLArrayContext( queue, allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)), ) diff --git a/grudge/array_context.py b/grudge/array_context.py index ec4aeae9f..12e97361b 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -132,6 +132,9 @@ def transform_loopy_program(self, t_unit): # {{{ process tensor product specific metadata + # NOTE: This differs from the lazy case b/c we don't have access to axis + # tags that can be manipulated pre-execution. In eager, we update + # strides/loop nest ordering for the output array if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): new_args = [] for arg in knl.args: @@ -151,36 +154,6 @@ def transform_loopy_program(self, t_unit): return super().transform_loopy_program(t_unit) - - def call_loopy(self, t_unit, **kwargs): - # NOTE: modifying strides pertaining to tensor product axes is done here - # since that information is not available in the arguments passed to - # `transform_loopy_program` in eager evaluation - - default_ep = t_unit.default_entrypoint - - # {{{ process kwargs with TP axis tags - - if default_ep.tags_of_type(OutputIsTensorProductDOFArrayOrdered): - new_args = [] - - for arg_name in kwargs.keys(): - kwarg = kwargs[arg_name].axes - - for axis in kwarg.axes: - if axis.tags_of_type(TensorProductDOFAxis): - if arg_name in default_ep.arg_dict.keys(): - arg = default_ep.arg_dict[arg_name] - - dim_tags = ( - - ) - - - # }}} - - return super().call_loopy(t_unit, **kwargs) - # }}} @@ -209,29 +182,6 @@ def __init__(self, queue, allocator=None, super().__init__(queue, allocator, compile_trace_callback=compile_trace_callback) - def transform_loopy_program(self, t_unit): - knl = t_unit.default_entrypoint - - # {{{ process tensor product specific metadata - - if knl.tags_of_type(OutputIsTensorProductDOFArrayOrdered): - new_args = [] - for arg in knl.args: - if arg.is_output: - arg = arg.copy(dim_tags=( - f"N{len(arg.shape)-1}," - + ",".join(f"N{i}" - for i in range(len(arg.shape)-1)) - )) - - new_args.append(arg) - - knl = knl.copy(args=new_args) - - # }}} - - return super().transform_loopy_program(t_unit) - # }}} @@ -669,7 +619,7 @@ def get_reasonable_array_context_class( # }}} -# {{{ tensor product-specific machinery +# {{{ tensor product discretization metadata class OutputIsTensorProductDOFArrayOrdered(Tag): """Signify that the strides will not be of order "C" or "F". See @@ -682,6 +632,7 @@ class OutputIsTensorProductDOFArrayOrdered(Tag): """ pass + class TensorProductDOFAxis(Tag): """ Tag an axis as being an axis containing the DOFs of a tensor-product @@ -697,6 +648,7 @@ class MassMatrix1d(Tag): """ pass + class InverseMassMatrix1d(Tag): """See MassMatrix1d. """ @@ -704,20 +656,4 @@ class InverseMassMatrix1d(Tag): # }}} -# {{{ Eager TP array context -class TensorProductArrayContext(_PyOpenCLArrayContextBase): - """Specialized array context for use with tensor product elements. - """ -# }}} - - -# {{{ Lazy tensor product array context -class PytatoTensorProductArrayContext(PytatoPyOpenCLArrayContext): - def transform_dag(self, dag): - return super().transform_dag(dag) -# }}} - - -# }}} - # vim: foldmethod=marker diff --git a/grudge/op.py b/grudge/op.py index cd9b3d08e..3be95be7d 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -212,8 +212,8 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, for ax in apply_mass_axes: vec_mass_applied = single_axis_operator_application( actx, grp.dim, mass_1d, ax, vec, - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("mass_1d", "vec") ) @@ -221,8 +221,8 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, stiff_1d, xyz_axis, vec_mass_applied, - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("stiff_1d", "vec_with_mass_applied")) ) @@ -241,8 +241,8 @@ def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, diff_mat, xyz_axis, vec, - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("diff_mat", "vec")) ) @@ -327,8 +327,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, for ax in apply_mass_axes: grad[xyz_axis] = single_axis_operator_application( actx, grp.dim, mass_1d, ax, grad[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("mass_1d", f"vec_{xyz_axis}") ) @@ -337,8 +337,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, stiff_1d, xyz_axis, grad[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("stiff_1d", f"vec_{xyz_axis}")) ) @@ -360,8 +360,8 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, grp.space, single_axis_operator_application( actx, grp.dim, diff_mat, xyz_axis, grad[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("diff_mat", f"vec_{xyz_axis}") ) ) @@ -452,15 +452,15 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): for ax in apply_mass_axes: ref[xyz_axis] = single_axis_operator_application( actx, grp.dim, mass_1d, ax, ref[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("mass_1d", f"vec_{func_axis}_{xyz_axis}") ) ref[xyz_axis] = single_axis_operator_application( actx, grp.dim, stiff_1d, xyz_axis, ref[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("stiff_1d", f"vec_{func_axis}_{xyz_axis}") ) @@ -477,8 +477,8 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): ref[xyz_axis] = single_axis_operator_application( actx, grp.dim, diff_mat, xyz_axis, ref[xyz_axis], - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("diff_mat", f"vec_{func_axis}_{xyz_axis}") ) @@ -500,6 +500,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): ) return div + # }}} @@ -1188,8 +1189,8 @@ def apply_to_tensor_product_elements(grp, jac_inv, vec, ref_inv_mass): for xyz_axis in range(grp.dim): vec = single_axis_operator_application( actx, grp.dim, ref_inv_mass, xyz_axis, vec, - tags=(FirstAxisIsElementsTag(), - OutputIsTensorProductDOFArrayOrdered(),), + tagged=(FirstAxisIsElementsTag(), + OutputIsTensorProductDOFArrayOrdered(),), arg_names=("ref_inv_mass_1d", "vec")) vec = unfold(grp.space, vec) @@ -1472,15 +1473,15 @@ def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: # {{{ general single axis operator application def single_axis_operator_application(actx, dim, operator, axis, data, - arg_names=None, tags=None): + arg_names=None, tagged=None): """ Used for applying 1D operators to a single axis of a tensor of DOF data. """ if not isinstance(arg_names, tuple): raise TypeError("arg_names must be a tuple.") - if not isinstance(tags, tuple): - raise TypeError("tags must be a tuple.") + if not isinstance(tagged, tuple): + raise TypeError("tagged must be a tuple.") operator_spec = 'ij' data_spec = f'e{"abcdefghklm"[:axis]}j{"nopqrstuvwxyz"[:dim-axis-1]}' @@ -1493,7 +1494,7 @@ def single_axis_operator_application(actx, dim, operator, axis, data, { i: TensorProductDOFAxis() for i in range(1, dim+1) }, actx.einsum(spec, operator, data, arg_names=arg_names, - tagged=tags) + tagged=tagged) ) return result From 90d70925cab40e98d838653692c72570c28430fd Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 2 Apr 2024 12:38:12 -0500 Subject: [PATCH 60/66] Update TP operator and data tagging --- .../tensor-product-examples/acoustic_pulse.py | 2 + grudge/array_context.py | 41 +----- grudge/op.py | 139 +++++++++--------- grudge/transform/metadata.py | 60 ++++++++ 4 files changed, 134 insertions(+), 108 deletions(-) create mode 100644 grudge/transform/metadata.py diff --git a/examples/tensor-product-examples/acoustic_pulse.py b/examples/tensor-product-examples/acoustic_pulse.py index 9520945f3..9eb1b1ce7 100644 --- a/examples/tensor-product-examples/acoustic_pulse.py +++ b/examples/tensor-product-examples/acoustic_pulse.py @@ -160,6 +160,8 @@ def run_acoustic_pulse(actx, } ) + print(actx) + # }}} # {{{ Euler operator diff --git a/grudge/array_context.py b/grudge/array_context.py index 12e97361b..0ad5b038c 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -36,6 +36,10 @@ FrozenSet) from dataclasses import dataclass from pytools.tag import Tag + +from grudge.transform.metadata import OutputIsTensorProductDOFArrayOrdered + +from meshmode.transform_metadata import DiscretizationDOFAxisTag from meshmode.array_context import ( PyOpenCLArrayContext as _PyOpenCLArrayContextBase, PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase) @@ -619,41 +623,4 @@ def get_reasonable_array_context_class( # }}} -# {{{ tensor product discretization metadata - -class OutputIsTensorProductDOFArrayOrdered(Tag): - """Signify that the strides will not be of order "C" or "F". See - :class:`grudge.array_context.TensorProductArrayContext` for more details. - - The strides for the arrays containing tensor product element data are of the - form (slow, fastest, faster, fast). These strides are not "C" or "F" order. - Hence, this specialized array context takes care of specifying the - particular strides required. - """ - pass - - -class TensorProductDOFAxis(Tag): - """ - Tag an axis as being an axis containing the DOFs of a tensor-product - discretization. Used to signify that the strides associated with the array - containing this axis will be neither column nor row major. - """ - pass - - -class MassMatrix1d(Tag): - """Used in DAG transformation to realize algebraic simplification of 1D - inverse mass operator times mass operator. - """ - pass - - -class InverseMassMatrix1d(Tag): - """See MassMatrix1d. - """ - -# }}} - - # vim: foldmethod=marker diff --git a/grudge/op.py b/grudge/op.py index 3be95be7d..14f827ba1 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -78,7 +78,7 @@ from meshmode.discretization.poly_element import ( TensorProductElementGroupBase as TensorProductElementGroup, SimplexElementGroupBase as SimplexElementGroup) -from meshmode.transform_metadata import (FirstAxisIsElementsTag, +from meshmode.transform_metadata import (DiscretizationAmbientDimAxisTag, FirstAxisIsElementsTag, DiscretizationDOFAxisTag, DiscretizationElementAxisTag, DiscretizationFaceAxisTag) @@ -89,9 +89,11 @@ from grudge.discretization import DiscretizationCollection from grudge.dof_desc import as_dofdesc -from grudge.array_context import ( +from grudge.transform.metadata import ( OutputIsTensorProductDOFArrayOrdered, - TensorProductDOFAxis + TensorProductDOFAxisTag, + MassMatrix1DTag, + InverseMassMatrix1DTag ) from pytools import keyed_memoize_in @@ -179,6 +181,8 @@ # {{{ common derivative "kernels" +# {{{ single axis derivative + def _single_axis_derivative_kernel( actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, xyz_axis, vec, *, metric_in_matvec): @@ -197,12 +201,7 @@ def _single_axis_derivative_kernel( def compute_tensor_product_derivative(actx, grp, get_diff_mat, vec, ijm, xyz_axis, metric_in_matvec): - vec = tag_axes( - actx, - { i: TensorProductDOFAxis() for i in range(1, grp.dim) }, - fold(grp.space, vec) - ) - + vec = fold(grp.space, vec) if metric_in_matvec: stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) @@ -292,14 +291,17 @@ def compute_simplicial_derivative(actx, in_grp, out_grp, out_discr.groups, in_discr.groups, vec, inv_jac_mat))) +# }}} + + +# {{{ gradient def _gradient_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, *, metric_in_matvec): # See _single_axis_derivative_kernel for comments on the usage scenarios # (both strong and weak derivative) and their differences. - - # {{{ tensor product gradient + # {{{ tensor product grad def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, metric_in_matvec): @@ -309,11 +311,7 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, """ # reshape vector to expose tensor product structure - vec = tag_axes( - actx, - { i: TensorProductDOFAxis() for i in range(1, grp.dim+1) }, - fold(grp.space, vec) - ) + vec = fold(grp.space, vec) if metric_in_matvec: stiff_1d, mass_1d = get_diff_mat(actx, grp, grp) @@ -342,14 +340,6 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, arg_names=("stiff_1d", f"vec_{xyz_axis}")) ) - # apply metric terms - grad[xyz_axis] = actx.einsum( - 'rej,ej->ej', - ijm[xyz_axis], - grad[xyz_axis], - tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", f"vec_{xyz_axis}") - ) else: diff_mat = get_diff_mat(actx, grp, grp) @@ -366,15 +356,21 @@ def compute_tensor_product_grad(actx, grp, diff_mat, vec, ijm, ) ) - grad[xyz_axis] = actx.einsum( - "rej,ej->ej", - ijm[xyz_axis], - grad[xyz_axis], - tagged=(FirstAxisIsElementsTag(),), - arg_names=("inv_jac_t", f"vec_{xyz_axis}") - ) - - return make_obj_array(grad) + grad = actx.np.stack(grad) + return tag_axes( + actx, + { + 0: DiscretizationAmbientDimAxisTag(), + 1: DiscretizationElementAxisTag(), + 2: DiscretizationDOFAxisTag() + }, + actx.einsum( + "xrej,rej->xej", + ijm, + grad, + tagged=(FirstAxisIsElementsTag(),), + arg_names=("inv_jac_t", f"vec") + )) # }}} @@ -415,6 +411,10 @@ def compute_simplicial_grad(actx, in_grp, out_grp, get_diff_mat, vec_i, actx, data=tuple([pgg_i[xyz_axis] for pgg_i in per_group_grads])) for xyz_axis in range(out_discr.ambient_dim)]) +# }}} + + +# {{{ divergence def _divergence_kernel(actx, out_discr, in_discr, get_diff_mat, inv_jac_mat, vec, *, metric_in_matvec): @@ -431,11 +431,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): """ vec = make_obj_array([ - tag_axes( - actx, - { i: TensorProductDOFAxis() for i in range(1,grp.dim+1) }, - fold(grp.space, vec[func_axis]) - ) + fold(grp.space, vec[func_axis]) for func_axis in range(vec.shape[0]) ]) @@ -501,6 +497,7 @@ def compute_tensor_product_div(actx, grp, diff_mat, vec, ijm): return div + # }}} @@ -542,6 +539,8 @@ def compute_simplicial_div(actx, in_grp, out_grp, get_diff_mat, vec_i, # }}} +# }}} + # {{{ Derivative operators @@ -555,6 +554,7 @@ def _reference_derivative_matrices(actx: ArrayContext, actx, _reference_derivative_matrices, lambda grp: grp.discretization_key()) def get_ref_derivative_mats(grp): + if isinstance(grp, TensorProductElementGroup): import modepy as mp import numpy.linalg as la @@ -569,14 +569,11 @@ def get_ref_derivative_mats(grp): diff_mat = actx.from_numpy(vdm_p_1d @ la.inv(vdm_1d)) from arraycontext.metadata import NameHint - return actx.freeze( - actx.tag(NameHint("tp_diff_mat_1d"), - tag_axes(actx, { - 1: DiscretizationDOFAxisTag()}, - diff_mat))) + return actx.freeze(actx.tag(NameHint("tp_diff_mat_1d"), diff_mat)) elif isinstance(grp, SimplexElementGroup): from meshmode.discretization.poly_element import diff_matrices + return actx.freeze( actx.tag_axis( 1, DiscretizationDOFAxisTag(), @@ -774,18 +771,11 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): diff_mat = la.solve(vdm.T, vdm_p.T).T stiff_1d = actx.freeze( - actx.tag_axis(1, DiscretizationDOFAxisTag(), - actx.from_numpy( - np.asarray( - diff_mat.T @ mass_1d.T)))) + actx.from_numpy(np.asarray(diff_mat.T @ mass_1d.T))) - from grudge.array_context import MassMatrix1d mass_1d = actx.freeze( - actx.tag_axis( - 1, (DiscretizationDOFAxisTag(),), - actx.from_numpy(np.asarray(mass_1d))) - ) - mass_1d = actx.tag(MassMatrix1d(), mass_1d) + actx.tag(MassMatrix1DTag(), + actx.from_numpy(np.asarray(mass_1d)))) return (stiff_1d, mass_1d) @@ -1133,18 +1123,17 @@ def get_ref_inv_mass_mat(grp): from modepy import inverse_mass_matrix if isinstance(grp, TensorProductElementGroup): + basis_1d = grp.bases_1d() nodes_1d = grp.unit_nodes_1d + inv_mass_1d = inverse_mass_matrix(basis_1d.functions, nodes_1d) + inv_mass_1d = actx.from_numpy(np.asarray(inv_mass_1d)) - from grudge.array_context import InverseMassMatrix1d - inv_mass_1d = actx.tag_axis(0, DiscretizationDOFAxisTag(), - actx.from_numpy(np.asarray(inv_mass_1d))) - inv_mass_1d = actx.freeze( - actx.tag(InverseMassMatrix1d(), inv_mass_1d)) + return actx.freeze(actx.tag(InverseMassMatrix1DTag(), inv_mass_1d)) - return inv_mass_1d elif isinstance(grp, SimplexElementGroup): + basis = grp.basis_obj() return actx.freeze( @@ -1472,33 +1461,41 @@ def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: # {{{ general single axis operator application -def single_axis_operator_application(actx, dim, operator, axis, data, +def single_axis_operator_application(actx, dim, operator, axis, vec, arg_names=None, tagged=None): """ Used for applying 1D operators to a single axis of a tensor of DOF data. """ - if not isinstance(arg_names, tuple): + if not isinstance(arg_names, tuple) and arg_names is not None: raise TypeError("arg_names must be a tuple.") - if not isinstance(tagged, tuple): + if not isinstance(tagged, tuple) and tagged is not None: raise TypeError("tagged must be a tuple.") + + vec = actx.tag_axis(0, DiscretizationElementAxisTag(), vec) + vec = tag_axes( + actx, + { i: TensorProductDOFAxisTag(i-1) for i in range(1, dim+1) }, + vec + ) + + # 3D grad example spec using formula below: + # x-axis (axis = 0) contraction: ij,ejop->eiop + # y-axis (axis = 1) contraction: ij,eajp->eaip + # z-axis (axis = 2) contraction: ij,eabj->eabi operator_spec = 'ij' - data_spec = f'e{"abcdefghklm"[:axis]}j{"nopqrstuvwxyz"[:dim-axis-1]}' - out_spec = f'e{"abcdefghklm"[:axis]}i{"nopqrstuvwxyz"[:dim-axis-1]}' + data_spec = f'e{"abcdefghklmn"[:axis]}j{"opqrstuvwxyz"[:dim-axis-1]}' + out_spec = f'e{"abcdefghklmn"[:axis]}i{"opqrstuvwxyz"[:dim-axis-1]}' spec = operator_spec + ',' + data_spec + '->' + out_spec - result = tag_axes( + return tag_axes( actx, - { i: TensorProductDOFAxis() for i in range(1, dim+1) }, - actx.einsum(spec, operator, data, - arg_names=arg_names, - tagged=tagged) + { i: TensorProductDOFAxisTag(i-1) for i in range(1, dim+1) }, + actx.einsum(spec, operator, vec, arg_names=arg_names, tagged=tagged) ) - return result - # }}} diff --git a/grudge/transform/metadata.py b/grudge/transform/metadata.py new file mode 100644 index 000000000..feda95b85 --- /dev/null +++ b/grudge/transform/metadata.py @@ -0,0 +1,60 @@ +__copyright__ = "Copyright (C) 2024 Addison Alvey-Blanco" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from pytools.tag import Tag, tag_dataclass +from meshmode.transform_metadata import DiscretizationDOFAxisTag + + +# {{{ tensor product specific metadata + +class OutputIsTensorProductDOFArrayOrdered(Tag): + # FIXME: REMOVE THIS + # /!\ THIS IS TEMPORARY AND WILL GO AWAY /!\ + """ + Signify that the strides will not be of order "C" or "F". + + Used to specify strides for eager einsums. + """ + pass + + +@tag_dataclass +class TensorProductDOFAxisTag(DiscretizationDOFAxisTag): + """ + Tag an axis as being an axis containing the DOFs of a tensor-product + discretization. + """ + iaxis: int + + +class MassMatrix1DTag(Tag): + """Used in DAG transformation to realize algebraic simplification of 1D + inverse mass operator times mass operator. + """ + pass + + +class InverseMassMatrix1DTag(Tag): + """See MassMatrix1d. + """ + +# }}} From 8ebccc7ae1a66272e13036b76bd7e61f3844780e Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 2 Apr 2024 12:46:20 -0500 Subject: [PATCH 61/66] Update a comment --- grudge/op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grudge/op.py b/grudge/op.py index 14f827ba1..63e90290f 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -1481,6 +1481,7 @@ def single_axis_operator_application(actx, dim, operator, axis, vec, ) # 3D grad example spec using formula below: + # assume operator is a differentiation operator # x-axis (axis = 0) contraction: ij,ejop->eiop # y-axis (axis = 1) contraction: ij,eajp->eaip # z-axis (axis = 2) contraction: ij,eabj->eabi From 29be28d4292907a2112f5f106a92fe0ad85079a6 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 2 Apr 2024 12:56:28 -0500 Subject: [PATCH 62/66] Add TensorProductOperatorAxisTag --- grudge/op.py | 20 ++++++++++++++------ grudge/transform/metadata.py | 7 ++++++- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 63e90290f..9c015a087 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -74,14 +74,16 @@ from functools import partial -from meshmode.dof_array import DOFArray, warn +from meshmode.dof_array import DOFArray from meshmode.discretization.poly_element import ( TensorProductElementGroupBase as TensorProductElementGroup, SimplexElementGroupBase as SimplexElementGroup) -from meshmode.transform_metadata import (DiscretizationAmbientDimAxisTag, FirstAxisIsElementsTag, - DiscretizationDOFAxisTag, - DiscretizationElementAxisTag, - DiscretizationFaceAxisTag) +from meshmode.transform_metadata import ( + DiscretizationAmbientDimAxisTag, + FirstAxisIsElementsTag, + DiscretizationDOFAxisTag, + DiscretizationElementAxisTag, + DiscretizationFaceAxisTag) from modepy.tools import ( reshape_array_for_tensor_product_space as fold, @@ -92,6 +94,7 @@ from grudge.transform.metadata import ( OutputIsTensorProductDOFArrayOrdered, TensorProductDOFAxisTag, + TensorProductOperatorAxisTag, MassMatrix1DTag, InverseMassMatrix1DTag ) @@ -1472,7 +1475,6 @@ def single_axis_operator_application(actx, dim, operator, axis, vec, if not isinstance(tagged, tuple) and tagged is not None: raise TypeError("tagged must be a tuple.") - vec = actx.tag_axis(0, DiscretizationElementAxisTag(), vec) vec = tag_axes( actx, @@ -1480,6 +1482,12 @@ def single_axis_operator_application(actx, dim, operator, axis, vec, vec ) + operator = tag_axes( + actx, + { i: TensorProductOperatorAxisTag() for i in range(2) }, + operator + ) + # 3D grad example spec using formula below: # assume operator is a differentiation operator # x-axis (axis = 0) contraction: ij,ejop->eiop diff --git a/grudge/transform/metadata.py b/grudge/transform/metadata.py index feda95b85..ecd753ee9 100644 --- a/grudge/transform/metadata.py +++ b/grudge/transform/metadata.py @@ -20,7 +20,7 @@ THE SOFTWARE. """ -from pytools.tag import Tag, tag_dataclass +from pytools.tag import Tag, DoNotPropagateTag, tag_dataclass from meshmode.transform_metadata import DiscretizationDOFAxisTag @@ -46,6 +46,11 @@ class TensorProductDOFAxisTag(DiscretizationDOFAxisTag): iaxis: int +@tag_dataclass +class TensorProductOperatorAxisTag(DoNotPropagateTag): + pass + + class MassMatrix1DTag(Tag): """Used in DAG transformation to realize algebraic simplification of 1D inverse mass operator times mass operator. From 607d368027e6ed5418f4335d6a9bd2085c574883 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 2 Apr 2024 12:57:47 -0500 Subject: [PATCH 63/66] Add some clarifying comments --- grudge/op.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/grudge/op.py b/grudge/op.py index 9c015a087..7cba3ce77 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -1475,6 +1475,8 @@ def single_axis_operator_application(actx, dim, operator, axis, vec, if not isinstance(tagged, tuple) and tagged is not None: raise TypeError("tagged must be a tuple.") + # {{{ ensure axes are properly tagged + vec = actx.tag_axis(0, DiscretizationElementAxisTag(), vec) vec = tag_axes( actx, @@ -1488,6 +1490,10 @@ def single_axis_operator_application(actx, dim, operator, axis, vec, operator ) + # }}} + + # {{{ einsum spec construction + # 3D grad example spec using formula below: # assume operator is a differentiation operator # x-axis (axis = 0) contraction: ij,ejop->eiop @@ -1499,6 +1505,8 @@ def single_axis_operator_application(actx, dim, operator, axis, vec, spec = operator_spec + ',' + data_spec + '->' + out_spec + # }}} + return tag_axes( actx, { i: TensorProductDOFAxisTag(i-1) for i in range(1, dim+1) }, From f03230d49b87ed747bbed26bd7de3e354c25d96f Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 2 Apr 2024 14:25:44 -0500 Subject: [PATCH 64/66] Update docs. Testing IgnoredForEqualityTag --- grudge/transform/metadata.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/grudge/transform/metadata.py b/grudge/transform/metadata.py index ecd753ee9..053209c72 100644 --- a/grudge/transform/metadata.py +++ b/grudge/transform/metadata.py @@ -20,7 +20,7 @@ THE SOFTWARE. """ -from pytools.tag import Tag, DoNotPropagateTag, tag_dataclass +from pytools.tag import IgnoredForEqualityTag, Tag, tag_dataclass from meshmode.transform_metadata import DiscretizationDOFAxisTag @@ -41,13 +41,19 @@ class OutputIsTensorProductDOFArrayOrdered(Tag): class TensorProductDOFAxisTag(DiscretizationDOFAxisTag): """ Tag an axis as being an axis containing the DOFs of a tensor-product - discretization. + discretization. Used to signify the relative update speed of an axis for + transformation (i.e. loop nest ordering) purposes. """ iaxis: int @tag_dataclass -class TensorProductOperatorAxisTag(DoNotPropagateTag): +class TensorProductOperatorAxisTag(IgnoredForEqualityTag): + """ + Signify that an axis is an operator of a tensor-product discretization. + Since these operators are reused, it is important to not propagate axis tags + along their axes. + """ pass From cadb98ab3eafd46e2bb6ce08ae94b9816f2811f7 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Wed, 3 Apr 2024 18:37:02 -0500 Subject: [PATCH 65/66] Rename some tags, change to IgnoredForPropagationTag in TP operator axis tag --- grudge/op.py | 4 ++-- grudge/transform/metadata.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/grudge/op.py b/grudge/op.py index 7cba3ce77..fa9cad155 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -95,8 +95,8 @@ OutputIsTensorProductDOFArrayOrdered, TensorProductDOFAxisTag, TensorProductOperatorAxisTag, - MassMatrix1DTag, - InverseMassMatrix1DTag + ReferenceTensorProductMassOperatorTag as MassMatrix1DTag, + ReferenceTensorProductInverseMassOperatorTag as InverseMassMatrix1DTag ) from pytools import keyed_memoize_in diff --git a/grudge/transform/metadata.py b/grudge/transform/metadata.py index 053209c72..d5634ebc0 100644 --- a/grudge/transform/metadata.py +++ b/grudge/transform/metadata.py @@ -20,7 +20,7 @@ THE SOFTWARE. """ -from pytools.tag import IgnoredForEqualityTag, Tag, tag_dataclass +from pytools.tag import IgnoredForPropagationTag, Tag, tag_dataclass from meshmode.transform_metadata import DiscretizationDOFAxisTag @@ -47,25 +47,25 @@ class TensorProductDOFAxisTag(DiscretizationDOFAxisTag): iaxis: int -@tag_dataclass -class TensorProductOperatorAxisTag(IgnoredForEqualityTag): +class TensorProductOperatorAxisTag(IgnoredForPropagationTag): """ - Signify that an axis is an operator of a tensor-product discretization. - Since these operators are reused, it is important to not propagate axis tags - along their axes. + Signify that an axis belongs to a 1D operator. No tags will be propagated + along an axis tagged with this tag. """ pass -class MassMatrix1DTag(Tag): - """Used in DAG transformation to realize algebraic simplification of 1D +class ReferenceTensorProductMassOperatorTag(Tag): + """ + Used in DAG transformation to realize algebraic simplification of 1D inverse mass operator times mass operator. """ pass -class InverseMassMatrix1DTag(Tag): - """See MassMatrix1d. +class ReferenceTensorProductInverseMassOperatorTag(Tag): + """ + See MassMatrix1d. """ # }}} From a662445c43d38d97324cb9744d144e0fa9a47a42 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Tue, 9 Apr 2024 11:10:26 -0500 Subject: [PATCH 66/66] Small change --- examples/tensor-product-examples/acoustic_pulse.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/tensor-product-examples/acoustic_pulse.py b/examples/tensor-product-examples/acoustic_pulse.py index 9eb1b1ce7..9520945f3 100644 --- a/examples/tensor-product-examples/acoustic_pulse.py +++ b/examples/tensor-product-examples/acoustic_pulse.py @@ -160,8 +160,6 @@ def run_acoustic_pulse(actx, } ) - print(actx) - # }}} # {{{ Euler operator