diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index e8797888d..49183e112 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -22,7 +22,7 @@ class PetscMat(OpaqueType): def __init__(self): - super(PetscMat, self).__init__(name="Mat") + super().__init__(name="Mat") class Map(object): @@ -114,14 +114,17 @@ def pick_loop_indices(self, loop_index, layer_index=None, entity_index=None): def kernel_arg(self, loop_indices=None): pass + @abstractmethod + def emit_pack_instruction(self, *, loop_indices=None): + """Either yield an instruction, or else return an empty tuple (to indicate no instruction)""" + @abstractmethod def pack(self, loop_indices=None): pass @abstractmethod - def emit_unpack_instruction(self, *, - loop_indices=None): - pass + def emit_unpack_instruction(self, *, loop_indices=None): + """Either yield an instruction, or else return an empty tuple (to indicate no instruction)""" class GlobalPack(Pack): @@ -133,12 +136,20 @@ def __init__(self, outer, access): def kernel_arg(self, loop_indices=None): return Indexed(self.outer, (Index(e) for e in self.outer.shape)) + def emit_pack_instruction(self, *, loop_indices=None): + shape = self.outer.shape + if self.access is WRITE: + zero = Zero((), self.outer.dtype) + multiindex = MultiIndex(*(Index(e) for e in shape)) + yield Accumulate(PackInst(), Indexed(self.outer, multiindex), zero) + else: + return () + def pack(self, loop_indices=None): return None - def emit_unpack_instruction(self, *, - loop_indices=None): - yield None + def emit_unpack_instruction(self, *, loop_indices=None): + return () class DatPack(Pack): @@ -215,13 +226,15 @@ def kernel_arg(self, loop_indices=None): shape = pack.shape return Indexed(pack, (Index(e) for e in shape)) - def emit_unpack_instruction(self, *, - loop_indices=None): + def emit_pack_instruction(self, *, loop_indices=None): + return () + + def emit_unpack_instruction(self, *, loop_indices=None): pack = self.pack(loop_indices) if pack is None: - yield None + return () elif self.access is READ: - yield None + return () elif self.access in {INC, MIN, MAX}: op = {INC: Sum, MIN: Min, @@ -295,10 +308,13 @@ def kernel_arg(self, loop_indices=None): shape = pack.shape return Indexed(pack, (Index(e) for e in shape)) + def emit_pack_instruction(self, *, loop_indices=None): + return () + def emit_unpack_instruction(self, *, loop_indices=None): pack = self.pack(loop_indices) if self.access is READ: - yield None + return () else: if self.interior_horizontal: _shape = (2,) @@ -368,8 +384,10 @@ def kernel_arg(self, loop_indices=None): pack = self.pack(loop_indices=loop_indices) return Indexed(pack, tuple(Index(e) for e in pack.shape)) - def emit_unpack_instruction(self, *, - loop_indices=None): + def emit_pack_instruction(self, *, loop_indices=None): + return () + + def emit_unpack_instruction(self, *, loop_indices=None): from pyop2.codegen.rep2loopy import register_petsc_function ((rdim, cdim), ), = self.dims rmap, cmap = self.maps @@ -428,7 +446,6 @@ class WrapperBuilder(object): def __init__(self, *, iterset, iteration_region=None, single_cell=False, pass_layer_to_kernel=False, forward_arg_types=()): - super().__init__() self.arguments = [] self.argument_accesses = [] self.packed_args = [] @@ -658,14 +675,13 @@ def kernel_call(self): return FunctionCall(self.kernel.name, KernelInst(), access, free_indices, *args) def emit_instructions(self): + yield from itertools.chain(*(pack.emit_pack_instruction(loop_indices=self.loop_indices) + for pack in self.packed_args)) # Sometimes, actual instructions do not refer to all the loop # indices (e.g. all of them are globals). To ensure that loopy # knows about these indices, we emit a dummy instruction (that # doesn't generate any code) that does depend on them. yield DummyInstruction(PackInst(), *(x for x in self.loop_indices if x is not None)) yield self.kernel_call() - for pack in self.packed_args: - insns = pack.emit_unpack_instruction(loop_indices=self.loop_indices) - for insn in insns: - if insn is not None: - yield insn + yield from itertools.chain(*(pack.emit_unpack_instruction(loop_indices=self.loop_indices) + for pack in self.packed_args))