Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #675 from OP2/ksagiyam/periodic_extrusion
Browse files Browse the repository at this point in the history
Ksagiyam/periodic extrusion
  • Loading branch information
ksagiyam authored Nov 15, 2022
2 parents 382a718 + e8722fb commit 9de5afc
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 67 deletions.
92 changes: 65 additions & 27 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from loopy.types import OpaqueType
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg)
from pyop2.codegen.representation import (Accumulate, Argument, Comparison,
from pyop2.codegen.representation import (Accumulate, Argument, Comparison, Conditional,
DummyInstruction, Extent, FixedIndex,
FunctionCall, Index, Indexed,
KernelInst, Literal, LogicalAnd,
Expand All @@ -28,19 +28,27 @@ def __init__(self):
super().__init__(name="Mat")


def _Remainder(a, b):
    """Build the expression ``a if a < b else a - b``.

    Ad hoc stand-in for ``Remainder(a, b)``; replace it with the real node
    once that gets fixed.  Because at most one ``b`` is subtracted, the
    result equals ``a mod b`` only while ``0 <= a < 2 * b``.

    :arg a: scalar expression to reduce.
    :arg b: scalar expression acting as the modulus.
    :returns: a ``Conditional`` expression node.
    """
    below_modulus = Comparison("<", a, b)
    minus_one = Literal(numpy.int32(-1))
    # a - b, spelled as a + (-1) * b.
    wrapped_once = Sum(a, Product(minus_one, b))
    return Conditional(below_modulus, a, wrapped_once)


class Map(object):

__slots__ = ("values", "offset", "interior_horizontal",
"variable", "unroll", "layer_bounds",
__slots__ = ("values", "extruded_periodic", "offset", "offset_quotient", "interior_horizontal",
"variable", "unroll", "layer_bounds", "num_layers",
"prefetch", "_pmap_count")

def __init__(self, interior_horizontal, layer_bounds,
def __init__(self, interior_horizontal, layer_bounds, num_layers,
arity, dtype,
offset=None, unroll=False,
extruded=False, constant_layers=False):
offset=None, offset_quotient=None, unroll=False,
extruded=False, extruded_periodic=False, constant_layers=False):
self.variable = extruded and not constant_layers
self.extruded_periodic = extruded_periodic
self.unroll = unroll
self.layer_bounds = layer_bounds
self.num_layers = num_layers
self.interior_horizontal = interior_horizontal
self.prefetch = {}

Expand All @@ -53,9 +61,14 @@ def __init__(self, interior_horizontal, layer_bounds,
offset = Literal(offset[0], casting=True)
else:
offset = NamedLiteral(offset, parent=values, suffix="offset")
if offset_quotient is not None:
assert type(offset_quotient) == tuple
offset_quotient = numpy.array(offset_quotient, dtype=numpy.int32)
offset_quotient = NamedLiteral(offset_quotient, parent=values, suffix="offset_quotient")

self.values = values
self.offset = offset
self.offset_quotient = offset_quotient
self._pmap_count = itertools.count()

@property
Expand Down Expand Up @@ -87,18 +100,29 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
if key is None:
key = 1
if key not in self.prefetch:
# See comments in "sparsity.pyx".
bottom_layer, _ = self.layer_bounds
k = Index(f.extent if f.extent is not None else 1)
offset = Sum(Sum(layer, Product(Literal(numpy.int32(-1)), bottom_layer)), k)
j = Index()
# Inline map offsets where all entries are identical.
if self.offset.shape == ():
offset = Product(offset, self.offset)
else:
offset = Product(offset, Indexed(self.offset, (j,)))
base = Indexed(base, (j, ))
unit_offset = self.offset if self.offset.shape == () else Indexed(self.offset, (j,))
if self.extruded_periodic:
if self.offset_quotient is None:
# Equivalent to offset_quotient[:] == 0.
# Avoid unnecessary logic below.
offset = _Remainder(offset, self.num_layers)
else:
effective_offset = Sum(offset, Indexed(self.offset_quotient, (j,)))
# The following code currently does not work: "undefined symbol: loopy_mod_int32"
# offset = Remainder(effective_offset, self.num_layers)
# Use a less elegant and less robust workaround for now.
offset = Sum(_Remainder(effective_offset, self.num_layers),
Product(Literal(numpy.int32(-1)),
_Remainder(Indexed(self.offset_quotient, (j,)), self.num_layers)))
# Inline map offsets where all entries are identical.
offset = Product(unit_offset, offset)
self.prefetch[key] = Materialise(PackInst(), Sum(base, offset), MultiIndex(k, j))

return Indexed(self.prefetch[key], (f, i)), (f, i)
else:
assert f.extent == 1 or f.extent is None
Expand All @@ -125,8 +149,10 @@ class PMap(Map):
def __init__(self, map_, permutation):
# Copy over properties
self.variable = map_.variable
self.extruded_periodic = map_.extruded_periodic
self.unroll = map_.unroll
self.layer_bounds = map_.layer_bounds
self.num_layers = map_.num_layers
self.interior_horizontal = map_.interior_horizontal
self.prefetch = {}
self.values = map_.values
Expand All @@ -143,6 +169,7 @@ def __init__(self, map_, permutation):
else:
offset = map_.offset
self.offset = offset
self.offset_quotient = map_.offset_quotient
self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}")

def indexed(self, multiindex, layer=None):
Expand Down Expand Up @@ -644,7 +671,7 @@ def emit_unpack_instruction(self, *,

class WrapperBuilder(object):

def __init__(self, *, kernel, subset, extruded, constant_layers, iteration_region=None, single_cell=False,
def __init__(self, *, kernel, subset, extruded, extruded_periodic, constant_layers, iteration_region=None, single_cell=False,
pass_layer_to_kernel=False, forward_arg_types=()):
self.kernel = kernel
self.local_knl_args = iter(kernel.arguments)
Expand All @@ -655,6 +682,7 @@ def __init__(self, *, kernel, subset, extruded, constant_layers, iteration_regio
self.maps = OrderedDict()
self.subset = subset
self.extruded = extruded
self.extruded_periodic = extruded_periodic
self.constant_layers = constant_layers
if iteration_region is None:
self.iteration_region = ALL
Expand Down Expand Up @@ -700,6 +728,14 @@ def _layers_array(self):
else:
return Argument((None, 2), IntType, name="layers")

@cached_property
def num_layers(self):
    """Materialised expression for ``layers[idx, 1] - 1 - layers[idx, 0]``.

    ``layers`` is ``self._layers_array`` and ``idx`` is ``self._layer_index``.
    """
    lower = Indexed(self._layers_array, (self._layer_index, FixedIndex(0)))
    upper = Indexed(self._layers_array, (self._layer_index, FixedIndex(1)))
    last = Sum(upper, Literal(IntType.type(-1)))
    # last - lower, spelled as last + (-1) * lower.
    span = Sum(last, Product(Literal(numpy.int32(-1)), lower))
    return Materialise(PackInst(), span, MultiIndex())

@cached_property
def bottom_layer(self):
if self.iteration_region == ON_TOP:
Expand All @@ -723,23 +759,23 @@ def top_layer(self):

@cached_property
def layer_extents(self):
cellStart = Indexed(self._layers_array, (self._layer_index, FixedIndex(0)))
cellEnd = Sum(Indexed(self._layers_array, (self._layer_index, FixedIndex(1))), Literal(IntType.type(-1)))
if self.iteration_region == ON_BOTTOM:
start = Indexed(self._layers_array, (self._layer_index, FixedIndex(0)))
end = Sum(Indexed(self._layers_array, (self._layer_index, FixedIndex(0))),
Literal(IntType.type(1)))
start = cellStart
end = Sum(cellStart, Literal(IntType.type(1)))
elif self.iteration_region == ON_TOP:
start = Sum(Indexed(self._layers_array, (self._layer_index, FixedIndex(1))),
Literal(IntType.type(-2)))
end = Sum(Indexed(self._layers_array, (self._layer_index, FixedIndex(1))),
Literal(IntType.type(-1)))
start = Sum(cellEnd, Literal(IntType.type(-1)))
end = cellEnd
elif self.iteration_region == ON_INTERIOR_FACETS:
start = Indexed(self._layers_array, (self._layer_index, FixedIndex(0)))
end = Sum(Indexed(self._layers_array, (self._layer_index, FixedIndex(1))),
Literal(IntType.type(-2)))
start = cellStart
if self.extruded_periodic:
end = cellEnd
else:
end = Sum(cellEnd, Literal(IntType.type(-1)))
elif self.iteration_region == ALL:
start = Indexed(self._layers_array, (self._layer_index, FixedIndex(0)))
end = Sum(Indexed(self._layers_array, (self._layer_index, FixedIndex(1))),
Literal(IntType.type(-1)))
start = cellStart
end = cellEnd
else:
raise ValueError("Unknown iteration region")
return (Materialise(PackInst(), start, MultiIndex()),
Expand Down Expand Up @@ -862,9 +898,11 @@ def _add_map(self, map_, unroll=False):
else:
map_ = Map(interior_horizontal,
(self.bottom_layer, self.top_layer),
arity=map_.arity, offset=map_.offset, dtype=IntType,
self.num_layers,
arity=map_.arity, offset=map_.offset, offset_quotient=map_.offset_quotient, dtype=IntType,
unroll=unroll,
extruded=self.extruded,
extruded_periodic=self.extruded_periodic,
constant_layers=self.constant_layers)
self.maps[key] = map_
return map_
Expand Down
23 changes: 16 additions & 7 deletions pyop2/codegen/rep2loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
LogicalNot, LogicalAnd, LogicalOr,
Materialise, Accumulate, FunctionCall, When,
Argument, Variable, Literal, NamedLiteral,
Symbol, Zero, Sum, Min, Max, Product)
Symbol, Zero, Sum, Min, Max, Product,
Quotient, FloorDiv, Remainder)
from pyop2.codegen.representation import (PackInst, UnpackInst, KernelInst, PreUnpackInst)
from pytools import ImmutableRecord
from pyop2.codegen.loopycompat import _match_caller_callee_argument_dimension_
Expand Down Expand Up @@ -853,18 +854,26 @@ def expression_uop(expr, parameters):

@expression.register(Sum)
@expression.register(Product)
@expression.register(Quotient)
@expression.register(FloorDiv)
@expression.register(Remainder)
@expression.register(LogicalAnd)
@expression.register(LogicalOr)
@expression.register(BitwiseAnd)
@expression.register(BitwiseOr)
def expression_binop(expr, parameters):
    """Lower an arithmetic, logical or bitwise node to its pymbolic primitive.

    :arg expr: the node to translate; its ``children`` are lowered first.
    :arg parameters: passed through unchanged to ``expression`` for each child.
    :returns: the pymbolic expression corresponding to ``type(expr)``.
    :raises KeyError: if ``type(expr)`` is in neither dispatch table.
    """
    children = tuple(expression(c, parameters) for c in expr.children)
    node_type = type(expr)

    # Division-like primitives take their two operands as separate
    # positional arguments, so the children are unpacked.
    two_operand = {Quotient: pym.Quotient,
                   FloorDiv: pym.FloorDiv,
                   Remainder: pym.Remainder}
    if node_type in two_operand:
        return two_operand[node_type](*children)

    # The remaining primitives take a single tuple of children.
    tuple_operand = {Sum: pym.Sum,
                     Product: pym.Product,
                     LogicalOr: pym.LogicalOr,
                     LogicalAnd: pym.LogicalAnd,
                     BitwiseOr: pym.BitwiseOr,
                     BitwiseAnd: pym.BitwiseAnd}
    return tuple_operand[node_type](children)


@expression.register(Min)
Expand Down
26 changes: 26 additions & 0 deletions pyop2/codegen/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,32 @@ def dtype(self):
return numpy.find_common_type([], [a.dtype, b.dtype])


class QuotientBase(Scalar):
    """Shared base for two-operand, division-like scalar nodes.

    Both operands must be scalar, i.e. have an empty ``shape``.
    """

    __slots__ = ("children", )

    def __init__(self, a, b):
        for operand in (a, b):
            assert not operand.shape
        self.children = (a, b)

    @cached_property
    def dtype(self):
        """Common dtype of the two operands."""
        lhs, rhs = self.children
        # NOTE(review): numpy.find_common_type is deprecated in recent NumPy
        # releases; kept here to match the rest of the module.
        return numpy.find_common_type([], [lhs.dtype, rhs.dtype])


class Quotient(QuotientBase):
    """Two-operand node that rep2loopy lowers to ``pym.Quotient``."""


class FloorDiv(QuotientBase):
    """Two-operand node that rep2loopy lowers to ``pym.FloorDiv``."""


class Remainder(QuotientBase):
    """Two-operand node that rep2loopy lowers to ``pym.Remainder``."""


class Indexed(Scalar):
__slots__ = ("children", )

Expand Down
9 changes: 8 additions & 1 deletion pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ class MapKernelArg:

arity: int
offset: Optional[Tuple[int, ...]] = None
offset_quotient: Optional[Tuple[int, ...]] = None

def __post_init__(self):
    """Reject unhashable ``offset`` and ``offset_quotient`` values.

    :raises ValueError: if either attribute is not hashable; ``offset`` is
        checked first.
    """
    checks = (
        ("offset", "The provided offset must be hashable"),
        ("offset_quotient", "The provided offset_quotient must be hashable"),
    )
    for attr_name, message in checks:
        if not isinstance(getattr(self, attr_name), collections.abc.Hashable):
            raise ValueError(message)

@property
def cache_key(self):
    """Tuple identifying this map argument for caching.

    ``offset_quotient`` is part of the key, so two arguments that differ
    only in ``offset_quotient`` do not share a cache entry.

    :returns: ``(type(self), arity, offset, offset_quotient)``.
    """
    return type(self), self.arity, self.offset, self.offset_quotient


@dataclass(eq=False, frozen=True)
Expand Down Expand Up @@ -231,6 +234,7 @@ class GlobalKernel(Cached):
:param arguments: An iterable of :class:`KernelArg` instances describing
the arguments to the global kernel.
:param extruded: Are we looping over an extruded mesh?
:param extruded_periodic: Flag for periodic extrusion.
:param constant_layers: If looping over an extruded mesh, are the layers the
same for each base entity?
:param subset: Are we iterating over a subset?
Expand Down Expand Up @@ -264,6 +268,7 @@ def _cache_key(cls, local_knl, arguments, **kwargs):

def __init__(self, local_kernel, arguments, *,
extruded=False,
extruded_periodic=False,
constant_layers=False,
subset=False,
iteration_region=None,
Expand All @@ -283,6 +288,7 @@ def __init__(self, local_kernel, arguments, *,
self.local_kernel = local_kernel
self.arguments = arguments
self._extruded = extruded
self._extruded_periodic = extruded_periodic
self._constant_layers = constant_layers
self._subset = subset
self._iteration_region = iteration_region
Expand Down Expand Up @@ -334,6 +340,7 @@ def builder(self):
builder = WrapperBuilder(kernel=self.local_kernel,
subset=self._subset,
extruded=self._extruded,
extruded_periodic=self._extruded_periodic,
constant_layers=self._constant_layers,
iteration_region=self._iteration_region,
pass_layer_to_kernel=self._pass_layer_arg)
Expand Down
3 changes: 3 additions & 0 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,10 +615,12 @@ def LegacyParloop(local_knl, iterset, *args, **kwargs):

global_knl_args = tuple(a.global_kernel_arg for a in args)
extruded = iterset._extruded
extruded_periodic = iterset._extruded_periodic
constant_layers = extruded and iterset.constant_layers
subset = isinstance(iterset, Subset)
global_knl = GlobalKernel(local_knl, global_knl_args,
extruded=extruded,
extruded_periodic=extruded_periodic,
constant_layers=constant_layers,
subset=subset,
**kwargs)
Expand Down Expand Up @@ -673,6 +675,7 @@ def generate_single_cell_wrapper(iterset, args, forward_args=(),
builder = WrapperBuilder(kernel=empty_knl,
subset=isinstance(iterset, Subset),
extruded=iterset._extruded,
extruded_periodic=iterset._extruded_periodic,
constant_layers=iterset._extruded and iterset.constant_layers,
single_cell=True,
forward_arg_types=forward_arg_types)
Expand Down
Loading

0 comments on commit 9de5afc

Please sign in to comment.