Skip to content

Commit

Permalink
Merge pull request #621 from OP2/wence/feature/composed-map
Browse files Browse the repository at this point in the history
Permuted maps
  • Loading branch information
dham authored Aug 11, 2021
2 parents b48f65d + a257d1d commit 9230ed6
Show file tree
Hide file tree
Showing 13 changed files with 151 additions and 99 deletions.
35 changes: 35 additions & 0 deletions pyop2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2675,6 +2675,41 @@ def __le__(self, o):
return self == o


class PermutedMap(Map):
"""Composition of a standard :class:`Map` with a constant permutation.
:arg map_: The map to permute.
:arg permutation: The permutation of the map indices.
Where normally staging to element data is performed as
.. code-block::
local[i] = global[map[i]]
With a :class:`PermutedMap` we instead get
.. code-block::
local[i] = global[map[permutation[i]]]
This might be useful if your local kernel wants data in a
different order to the one that the map provides, and you don't
want two global-sized data structures.
"""
def __init__(self, map_, permutation):
self.map_ = map_
self.permutation = np.asarray(permutation, dtype=Map.dtype)
assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all()

@cached_property
def _wrapper_cache_key_(self):
return super()._wrapper_cache_key_ + (tuple(self.permutation),)

def __getattr__(self, name):
return getattr(self.map_, name)


class MixedMap(Map, ObjectCached):
r"""A container for a bag of :class:`Map`\s."""

Expand Down
22 changes: 17 additions & 5 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
When, Zero)
from pyop2.datatypes import IntType
from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS,
ON_TOP, READ, RW, WRITE, Subset)
ON_TOP, READ, RW, WRITE, Subset, PermutedMap)
from pyop2.utils import cached_property


Expand All @@ -30,7 +30,7 @@ class Map(object):

__slots__ = ("values", "offset", "interior_horizontal",
"variable", "unroll", "layer_bounds",
"prefetch")
"prefetch", "permutation")

def __init__(self, map_, interior_horizontal, layer_bounds,
values=None, offset=None, unroll=False):
Expand All @@ -50,11 +50,17 @@ def __init__(self, map_, interior_horizontal, layer_bounds,
offset = map_.offset
shape = (None, ) + map_.shape[1:]
values = Argument(shape, dtype=map_.dtype, pfx="map")
if isinstance(map_, PermutedMap):
self.permutation = NamedLiteral(map_.permutation, parent=values, suffix="permutation")
if offset is not None:
offset = offset[map_.permutation]
else:
self.permutation = None
if offset is not None:
if len(set(map_.offset)) == 1:
offset = Literal(offset[0], casting=True)
else:
offset = NamedLiteral(offset, name=values.name + "_offset")
offset = NamedLiteral(offset, parent=values, suffix="offset")

self.values = values
self.offset = offset
Expand All @@ -76,7 +82,10 @@ def indexed(self, multiindex, layer=None):
base_key = None
if base_key not in self.prefetch:
j = Index()
base = Indexed(self.values, (n, j))
if self.permutation is None:
base = Indexed(self.values, (n, j))
else:
base = Indexed(self.values, (n, Indexed(self.permutation, (j,))))
self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))

base = self.prefetch[base_key]
Expand All @@ -103,7 +112,10 @@ def indexed(self, multiindex, layer=None):
return Indexed(self.prefetch[key], (f, i)), (f, i)
else:
assert f.extent == 1 or f.extent is None
base = Indexed(self.values, (n, i))
if self.permutation is None:
base = Indexed(self.values, (n, i))
else:
base = Indexed(self.values, (n, Indexed(self.permutation, (i,))))
return base, (f, i)

def indexed_vector(self, n, shape, layer=None):
Expand Down
30 changes: 8 additions & 22 deletions pyop2/codegen/optimise.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pyop2.codegen.node import traversal, reuse_if_untouched, Memoizer
from functools import singledispatch
from pyop2.codegen.representation import (Index, RuntimeIndex, Node,
FunctionCall, Variable, Argument,
NamedLiteral)
FunctionCall, Variable, Argument)


def collect_indices(expressions):
Expand Down Expand Up @@ -90,7 +89,7 @@ def index_merger(instructions, cache=None):

@singledispatch
def _rename_node(node, self):
"""Replace division with multiplication
"""Rename nodes
:param node: root of expression
:param self: function for recursive calls
Expand All @@ -103,7 +102,7 @@ def _rename_node(node, self):

@_rename_node.register(Index)
def _rename_node_index(node, self):
name = self.replace.get(node, node.name)
name = self.renamer(node)
return Index(extent=node.extent, name=name)


Expand All @@ -114,38 +113,25 @@ def _rename_node_func(node, self):
return FunctionCall(node.name, node.label, node.access, free_indices, *children)


@_rename_node.register(RuntimeIndex)
def _rename_node_rtindex(node, self):
children = tuple(map(self, node.children))
name = self.replace.get(node, node.name)
return RuntimeIndex(*children, name=name)


@_rename_node.register(NamedLiteral)
def _rename_node_namedliteral(node, self):
name = self.replace.get(node, node.name)
return NamedLiteral(node.value, name)


@_rename_node.register(Variable)
def _rename_node_variable(node, self):
name = self.replace.get(node, node.name)
name = self.renamer(node)
return Variable(name, node.shape, node.dtype)


@_rename_node.register(Argument)
def _rename_node_argument(node, self):
name = self.replace.get(node, node.name)
name = self.renamer(node)
return Argument(node.shape, node.dtype, name=name)


def rename_nodes(instructions, replace):
def rename_nodes(instructions, renamer):
"""Rename the nodes in the instructions.
:param instructions: Iterable of nodes.
:param replace: Dictionary matching old names to new names.
:param renamer: Function that maps nodes to new names
:return: List of instructions with nodes renamed.
"""
mapper = Memoizer(_rename_node)
mapper.replace = replace
mapper.renamer = renamer
return list(map(mapper, instructions))
46 changes: 28 additions & 18 deletions pyop2/codegen/rep2loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from collections import OrderedDict, defaultdict
from functools import singledispatch, reduce
import itertools
import re
import operator

from pyop2.codegen.node import traversal, Node, Memoizer, reuse_if_untouched
Expand Down Expand Up @@ -430,24 +429,35 @@ def generate(builder, wrapper_name=None):
instructions = instructions + initialiser
mapper.initialisers = [tuple(merger(i) for i in inits) for inits in mapper.initialisers]

def name_generator(prefix):
yield from (f"{prefix}{i}" for i in itertools.count())

# rename indices and nodes (so that the counters start from zero)
pattern = re.compile(r"^([a-zA-Z_]+)([0-9]+)(_offset)?$")
replacements = {}
counter = defaultdict(itertools.count)
for node in traversal(instructions):
if isinstance(node, (Index, RuntimeIndex, Variable, Argument, NamedLiteral)):
match = pattern.match(node.name)
if match is None:
continue
prefix, _, postfix = match.groups()
if postfix is None:
postfix = ""
replacements[node] = "%s%d%s" % (prefix, next(counter[(prefix, postfix)]), postfix)

instructions = rename_nodes(instructions, replacements)
mapper.initialisers = [rename_nodes(inits, replacements) for inits in mapper.initialisers]
parameters.wrapper_arguments = rename_nodes(parameters.wrapper_arguments, replacements)
s, e = rename_nodes([mapper(e) for e in builder.layer_extents], replacements)
node_names = {}
node_namers = dict((cls, name_generator(prefix))
for cls, prefix in [(Index, "i"), (Variable, "t")])

def renamer(expr):
if isinstance(expr, Argument):
if expr._name is not None:
# Some arguments have given names
return expr._name
else:
# Otherwise generate one with their given prefix.
namer = node_namers.setdefault((type(expr), expr.prefix),
name_generator(expr.prefix))
else:
namer = node_namers[type(expr)]
try:
return node_names[expr]
except KeyError:
return node_names.setdefault(expr, next(namer))

instructions = rename_nodes(instructions, renamer)
mapper.initialisers = [rename_nodes(inits, renamer)
for inits in mapper.initialisers]
parameters.wrapper_arguments = rename_nodes(parameters.wrapper_arguments, renamer)
s, e = rename_nodes([mapper(e) for e in builder.layer_extents], renamer)
parameters.layer_start = s.name
parameters.layer_end = e.name

Expand Down
48 changes: 31 additions & 17 deletions pyop2/codegen/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class IndexBase(metaclass=ABCMeta):

class Index(Terminal, Scalar):
_count = itertools.count()
__slots__ = ("name", "extent", "merge")
__front__ = ("name", "extent", "merge")
__slots__ = ("extent", "merge", "name")
__front__ = ("extent", "merge", "name")

def __init__(self, extent=None, merge=True, name=None):
self.name = name or "i%d" % next(Index._count)
Expand Down Expand Up @@ -118,11 +118,12 @@ def __init__(self, value):

class RuntimeIndex(Scalar):
_count = itertools.count()
__slots__ = ("name", "children")
__slots__ = ("children", "name")
__back__ = ("name", )

def __init__(self, lo, hi, constraint, name=None):
self.name = name or "r%d" % next(RuntimeIndex._count)
def __init__(self, lo, hi, constraint, name):
assert name is not None, "runtime indices need a name"
self.name = name
self.children = lo, hi, constraint

@cached_property
Expand Down Expand Up @@ -173,17 +174,23 @@ def __init__(self, name):
class Argument(Terminal):
_count = defaultdict(partial(itertools.count))

__slots__ = ("shape", "dtype", "name")
__front__ = ("shape", "dtype", "name")
__slots__ = ("shape", "dtype", "_name", "prefix", "_gen_name")
__front__ = ("shape", "dtype", "_name", "prefix")

def __init__(self, shape, dtype, name=None, pfx=None):
self.dtype = dtype
self.shape = shape
if name is None:
if pfx is None:
pfx = "v"
name = "%s%d" % (pfx, next(Argument._count[pfx]))
self.name = name
self._name = name
pfx = pfx or "v"
self.prefix = pfx
self._gen_name = name or "%s%d" % (pfx, next(Argument._count[pfx]))

def get_hash(self):
return hash((type(self),) + self._cons_args(self.children) + (self.name,))

@property
def name(self):
return self._name or self._gen_name


class Literal(Terminal, Scalar):
Expand Down Expand Up @@ -218,19 +225,22 @@ def dtype(self):


class NamedLiteral(Terminal):
__slots__ = ("value", "name")
__front__ = ("value", "name")
__slots__ = ("value", "parent", "suffix")
__front__ = ("value", "parent", "suffix")

def __init__(self, value, name):
def __init__(self, value, parent, suffix):
self.value = value
self.name = name
self.parent = parent
self.suffix = suffix

def is_equal(self, other):
if type(self) != type(other):
return False
if self.shape != other.shape:
return False
if self.name != other.name:
if self.parent != other.parent:
return False
if self.suffix != other.suffix:
return False
return tuple(self.value.flat) == tuple(other.value.flat)

Expand All @@ -245,6 +255,10 @@ def shape(self):
def dtype(self):
return self.value.dtype

@property
def name(self):
return f"{self.parent.name}_{self.suffix}"


class Min(Scalar):
__slots__ = ("children", )
Expand Down
4 changes: 2 additions & 2 deletions pyop2/op2.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from pyop2.sequential import READ, WRITE, RW, INC, MIN, MAX # noqa: F401
from pyop2.base import ON_BOTTOM, ON_TOP, ON_INTERIOR_FACETS, ALL # noqa: F401
from pyop2.sequential import Set, ExtrudedSet, MixedSet, Subset, DataSet, MixedDataSet # noqa: F401
from pyop2.sequential import Map, MixedMap, Sparsity, Halo # noqa: F401
from pyop2.sequential import Map, MixedMap, PermutedMap, Sparsity, Halo # noqa: F401
from pyop2.sequential import Global, GlobalDataSet # noqa: F401
from pyop2.sequential import Dat, MixedDat, DatView, Mat # noqa: F401
from pyop2.sequential import ParLoop as SeqParLoop
Expand All @@ -59,7 +59,7 @@
'MixedSet', 'Subset', 'DataSet', 'GlobalDataSet', 'MixedDataSet',
'Halo', 'Dat', 'MixedDat', 'Mat', 'Global', 'Map', 'MixedMap',
'Sparsity', 'par_loop', 'ParLoop',
'DatView']
'DatView', 'PermutedMap']


def ParLoop(kernel, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion pyop2/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from pyop2.base import par_loop # noqa: F401
from pyop2.base import READ, WRITE, RW, INC, MIN, MAX # noqa: F401
from pyop2.base import ALL
from pyop2.base import Map, MixedMap, Sparsity, Halo # noqa: F401
from pyop2.base import Map, MixedMap, PermutedMap, Sparsity, Halo # noqa: F401
from pyop2.base import Set, ExtrudedSet, MixedSet, Subset # noqa: F401
from pyop2.base import DatView # noqa: F401
from pyop2.base import Kernel # noqa: F401
Expand Down
Loading

0 comments on commit 9230ed6

Please sign in to comment.