Skip to content

Commit

Permalink
remove apply_default_restrictions()
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Nov 27, 2024
1 parent 08351d0 commit 8334f62
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 81 deletions.
4 changes: 2 additions & 2 deletions test/test_apply_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
i,
triangle,
)
from ufl.algorithms.apply_restrictions import apply_default_restrictions, apply_restrictions
from ufl.algorithms.apply_restrictions import apply_restrictions
from ufl.algorithms.renumbering import renumber_indices
from ufl.finiteelement import FiniteElement
from ufl.pullback import identity_pullback
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_apply_restrictions():
assert apply_restrictions((grad(f) + grad(g))("-")) == (grad(f)("-") + grad(g)("-"))

# x is the same from both sides but computed from one of them
assert apply_default_restrictions(x) == x("+")
assert apply_restrictions(x) == x("+")

# n on a linear mesh is opposite pointing from the other side
assert apply_restrictions(n("+")) == n("+")
Expand Down
99 changes: 26 additions & 73 deletions ufl/algorithms/apply_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@
class RestrictionPropagator(MultiFunction):
"""Restriction propagator."""

def __init__(self, side=None):
def __init__(self, side=None, apply_default=True):
"""Initialise."""
MultiFunction.__init__(self)
self.current_restriction = side
self.default_restriction = "+"
self.apply_default = apply_default
# Caches for propagating the restriction with map_expr_dag
self.vcaches = {"+": {}, "-": {}}
self.rcaches = {"+": {}, "-": {}}
if self.current_restriction is None:
self._rp = {"+": RestrictionPropagator("+"), "-": RestrictionPropagator("-")}
self._rp = {
"+": RestrictionPropagator(side="+", apply_default=apply_default),
"-": RestrictionPropagator(side="-", apply_default=apply_default),
}

def restricted(self, o):
"""When hitting a restricted quantity, visit child with a separate restriction algorithm."""
Expand Down Expand Up @@ -64,9 +68,12 @@ def _require_restriction(self, o):
def _default_restricted(self, o):
"""Restrict a continuous quantity to default side if no current restriction is set."""
r = self.current_restriction
if r is None:
r = self.default_restriction
return o(r)
if r is not None:
return o(r)
if self.apply_default:
return o(self.default_restriction)
else:
return o

def _opposite(self, o):
"""Restrict a quantity to default side.
Expand Down Expand Up @@ -139,6 +146,18 @@ def reference_value(self, o):
reference_cell_volume = _ignore_restriction
reference_facet_volume = _ignore_restriction

# These are the same from either side but to compute them
# cell (or facet) data from one side must be selected:
spatial_coordinate = _default_restricted
# Depends on cell only to get to the facet:
facet_jacobian = _default_restricted
facet_jacobian_determinant = _default_restricted
facet_jacobian_inverse = _default_restricted
facet_area = _default_restricted
min_facet_edge_length = _default_restricted
max_facet_edge_length = _default_restricted
facet_origin = _default_restricted # FIXME: Is this valid for quads?

def coefficient(self, o):
"""Restrict a coefficient.
Expand Down Expand Up @@ -174,76 +193,10 @@ def facet_normal(self, o):
return self._require_restriction(o)


def apply_restrictions(expression):
def apply_restrictions(expression, apply_default=True):
"""Propagate restriction nodes to wrap differential terminals directly."""
integral_types = [
k for k in integral_type_to_measure_name.keys() if k.startswith("interior_facet")
]
rules = RestrictionPropagator()
return map_integrand_dags(rules, expression, only_integral_type=integral_types)


class DefaultRestrictionApplier(MultiFunction):
"""Default restriction applier."""

def __init__(self, side=None):
"""Initialise."""
MultiFunction.__init__(self)
self.current_restriction = side
self.default_restriction = "+"
if self.current_restriction is None:
self._rp = {"+": DefaultRestrictionApplier("+"), "-": DefaultRestrictionApplier("-")}

def terminal(self, o):
"""Apply to terminal."""
# Most terminals are unchanged
return o

# Default: Operators should reconstruct only if subtrees are not touched
operator = MultiFunction.reuse_if_untouched

def restricted(self, o):
"""Apply to restricted."""
# Don't restrict twice
return o

def derivative(self, o):
"""Apply to derivative."""
# I don't think it's safe to just apply default restriction
# to the argument of any derivative, i.e. grad(cg1_function)
# is not continuous across cells even if cg1_function is.
return o

def _default_restricted(self, o):
"""Restrict a continuous quantity to default side if no current restriction is set."""
r = self.current_restriction
if r is None:
r = self.default_restriction
return o(r)

# These are the same from either side but to compute them
# cell (or facet) data from one side must be selected:
spatial_coordinate = _default_restricted
# Depends on cell only to get to the facet:
facet_jacobian = _default_restricted
facet_jacobian_determinant = _default_restricted
facet_jacobian_inverse = _default_restricted
# facet_tangents = _default_restricted
# facet_midpoint = _default_restricted
facet_area = _default_restricted
# facet_diameter = _default_restricted
min_facet_edge_length = _default_restricted
max_facet_edge_length = _default_restricted
facet_origin = _default_restricted # FIXME: Is this valid for quads?


def apply_default_restrictions(expression):
"""Some terminals can be restricted from either side.
This applies a default restriction to such terminals if unrestricted.
"""
integral_types = [
k for k in integral_type_to_measure_name.keys() if k.startswith("interior_facet")
]
rules = DefaultRestrictionApplier()
rules = RestrictionPropagator(apply_default=apply_default)
return map_integrand_dags(rules, expression, only_integral_type=integral_types)
8 changes: 2 additions & 6 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ufl.algorithms.apply_function_pullbacks import apply_function_pullbacks
from ufl.algorithms.apply_geometry_lowering import apply_geometry_lowering
from ufl.algorithms.apply_integral_scaling import apply_integral_scaling
from ufl.algorithms.apply_restrictions import apply_default_restrictions, apply_restrictions
from ufl.algorithms.apply_restrictions import apply_restrictions
from ufl.algorithms.check_arities import check_form_arity
from ufl.algorithms.comparison_checker import do_comparison_check

Expand Down Expand Up @@ -306,10 +306,6 @@ def compute_form_data(
if do_apply_integral_scaling:
form = apply_integral_scaling(form)

# Apply default restriction to fully continuous terminals
if do_apply_default_restrictions:
form = apply_default_restrictions(form)

# Lower abstractions for geometric quantities into a smaller set
# of quantities, allowing the form compiler to deal with a smaller
# set of types and treating geometric quantities like any other
Expand All @@ -334,7 +330,7 @@ def compute_form_data(

# Propagate restrictions to terminals
if do_apply_restrictions:
form = apply_restrictions(form)
form = apply_restrictions(form, apply_default=do_apply_default_restrictions)

# If in real mode, remove any complex nodes introduced during form processing.
if not complex_mode:
Expand Down

0 comments on commit 8334f62

Please sign in to comment.