From cc90f96c70eeb7252a0cd3e0a17a48ae52422389 Mon Sep 17 00:00:00 2001
From: Jesse Chan <1156048+jlchan@users.noreply.github.com>
Date: Mon, 1 Jul 2024 09:04:16 -0500
Subject: [PATCH] Add collapsed coordinate quadrature for non-quad/hex elements
 (#176)

* add tensor product quadrature options for tri, tet, pyr

* add tensor product quadrature for wedge

* add tests

* docstrings
---
 src/RefElemData_polynomial.jl | 48 ++++++++++++++++++++++++++++++-----
 test/reference_elem_tests.jl  |  6 +++++
 2 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/src/RefElemData_polynomial.jl b/src/RefElemData_polynomial.jl
index 748f4a03..c941b421 100644
--- a/src/RefElemData_polynomial.jl
+++ b/src/RefElemData_polynomial.jl
@@ -17,13 +17,6 @@ RefElemData(elem::Line, approx_type::Polynomial{<:TensorProductQuadrature}, N; k
     RefElemData(elem, Polynomial{MultidimensionalQuadrature}(), N; 
                 quad_rule_vol=approx_type.data.quad_rule_1D, kwargs...)
 
-function RefElemData(elem::Union{Tri, Tet, Wedge, Pyr}, 
-            approx_type::Polynomial{<:TensorProductQuadrature}, 
-            N; kwargs...)
-    error("Tensor product quadrature constructors not yet implemented " * 
-          "for Tri, Tet, Wedge, Pyr elements.")
-end
-
 """
     RefElemData(elem::Line, approximation_type, N;
                 quad_rule_vol = quad_nodes(elem, N+1))
@@ -473,6 +466,47 @@ function tensor_product_quadrature(::Union{Tet, Hex}, r1D, w1D)
     return rq, sq, tq, wq
 end
 
+"""
+    RefElemData(elem::Union{Tri, Tet, Pyr}, approx_type::Polynomial{<:TensorProductQuadrature}, N; kwargs...)
+    RefElemData(elem::Union{Wedge}, 
+                     approx_type::Polynomial{<:TensorProductQuadrature}, N; 
+                     quad_rule_tri = stroud_quad_nodes(Tri(), 2 * N),
+                     quad_rule_line = gauss_quad(0, 0, N),
+                     kwargs...)
+
+Uses collapsed coordinate volume quadrature. Should be called via
+```julia
+RefElemData(Tri(), Polynomial(TensorProductQuadrature()), N)
+```
+"""
+function RefElemData(elem::Union{Tri, Tet, Pyr}, approx_type::Polynomial{<:TensorProductQuadrature}, N; kwargs...)
+    rd = RefElemData(elem, Polynomial{MultidimensionalQuadrature}(), N; 
+                     quad_rule_vol=stroud_quad_nodes(elem, 2 * N), kwargs...)
+    @set rd.approximation_type = approx_type
+    return rd
+end
+
+function RefElemData(elem::Union{Wedge}, 
+                     approx_type::Polynomial{<:TensorProductQuadrature}, N; 
+                     quad_rule_tri = stroud_quad_nodes(Tri(), 2 * N),
+                     quad_rule_line = gauss_quad(0, 0, N),
+                     kwargs...)
+
+    rq_tri, sq_tri, wq_tri = quad_rule_tri
+    rq_1D, wq_1D = quad_rule_line
+    rq = repeat(rq_tri, length(rq_1D))
+    sq = repeat(sq_tri, length(rq_1D))
+    tq = repeat(rq_1D, length(rq_tri))
+    wq = repeat(wq_tri, length(wq_1D)) .* repeat(wq_1D, length(wq_tri))
+    quad_rule_vol = (rq, sq, tq, wq)
+
+    rd = RefElemData(elem, Polynomial{MultidimensionalQuadrature}(), N; 
+                     quad_rule_vol, kwargs...)
+    @set rd.approximation_type = approx_type
+    return rd          
+end
+
+
 """
     RefElemData(elem::Union{Line, Quad, Hex}, approximation_type::Polynomial{Gauss}, N)
 
diff --git a/test/reference_elem_tests.jl b/test/reference_elem_tests.jl
index 46f7a419..a0f56ffd 100644
--- a/test/reference_elem_tests.jl
+++ b/test/reference_elem_tests.jl
@@ -221,6 +221,12 @@
 
         @test StartUpDG._short_typeof(rd.element_type) == "Pyr"
     end
+
+    @testset "Collapsed coordinate quadratures" for elem in [Tri(), Tet(), Wedge(), Pyr()]
+        rd = RefElemData(elem, Polynomial(), 2)
+        rd_tp = RefElemData(elem, Polynomial(TensorProductQuadrature()), 2)
+        @test rd.M ≈ rd_tp.M
+    end
 end
 
 inverse_trace_constant_compare(rd::RefElemData{3, <:Wedge, <:TensorProductWedge}) =