Skip to content

Commit

Permalink
WIP: enable macro-quadrature
Browse files Browse the repository at this point in the history
  • Loading branch information
rckirby committed Apr 3, 2024
1 parent 23ae55f commit be216cc
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,24 @@
import collections
import string
import operator
import string
from functools import reduce
from itertools import chain, product

import gem
import gem.impero_utils as impero_utils
import numpy
from numpy import asarray

from ufl.utils.sequences import max_degree

from FIAT.reference_element import TensorProductCell

from finat.quadrature import AbstractQuadratureRule, make_quadrature

import gem

from gem.node import traversal
from gem.optimise import constant_fold_zero
from gem.optimise import remove_componenttensors as prune
from gem.utils import cached_property
import gem.impero_utils as impero_utils
from gem.optimise import remove_componenttensors as prune, constant_fold_zero

from numpy import asarray
from tsfc import fem, ufl_utils
from tsfc.finatinterface import as_fiat_cell, convert, create_element
from tsfc.kernel_interface import KernelInterface
from tsfc.finatinterface import as_fiat_cell, create_element
from tsfc.logging import logger
from ufl.utils.sequences import max_degree


class KernelBuilderBase(KernelInterface):
Expand Down Expand Up @@ -301,7 +296,8 @@ def set_quad_rule(params, cell, integral_type, functions):
quadrature_degree = params["quadrature_degree"]
except KeyError:
quadrature_degree = params["estimated_polynomial_degree"]
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
function_degrees = [f.ufl_function_space().ufl_element().degree()
for f in functions]
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
for degree in function_degrees):
logger.warning("Estimated quadrature degree %s more "
Expand All @@ -314,9 +310,14 @@ def set_quad_rule(params, cell, integral_type, functions):
quad_rule = params["quadrature_rule"]
except KeyError:
fiat_cell = as_fiat_cell(cell)
finat_elements = set(create_element(f.ufl_element()) for f in functions)
print(list(f.degree for f in finat_elements))

fiat_cells = [fiat_cell] + [finat_el.complex for finat_el in finat_elements]

integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
integration_cell = fiat_cell.construct_subelement(integration_dim)
quad_rule = make_quadrature(integration_cell, quadrature_degree)

quad_rule = make_quadrature(fiat_cells, quadrature_degree, dim=integration_dim)
params["quadrature_rule"] = quad_rule

if not isinstance(quad_rule, AbstractQuadratureRule):
Expand Down

0 comments on commit be216cc

Please sign in to comment.