diff --git a/pyop2/codegen/rep2loopy.py b/pyop2/codegen/rep2loopy.py index b9b8d5636..1b9f9f43e 100644 --- a/pyop2/codegen/rep2loopy.py +++ b/pyop2/codegen/rep2loopy.py @@ -432,6 +432,16 @@ def generate(builder, wrapper_name=None): if wrapper_name is None: wrapper_name = "wrap_%s" % builder.kernel.name + pwaffd = isl.affs_from_space(assumptions.get_space()) + assumptions = assumptions & pwaffd["start"].ge_set(pwaffd[0]) + if builder.single_cell: + assumptions = assumptions & pwaffd["start"].lt_set(pwaffd["end"]) + else: + assumptions = assumptions & pwaffd["start"].le_set(pwaffd["end"]) + if builder.extruded: + assumptions = assumptions & pwaffd[parameters.layer_start].le_set(pwaffd[parameters.layer_end]) + assumptions = reduce(operator.and_, assumptions.get_basic_sets()) + wrapper = loopy.make_kernel(domains, statements, kernel_data=parameters.kernel_data, @@ -443,15 +453,6 @@ def generate(builder, wrapper_name=None): lang_version=(2018, 2), name=wrapper_name) - # additional assumptions - if builder.single_cell: - wrapper = loopy.assume(wrapper, "start < end") - else: - wrapper = loopy.assume(wrapper, "start <= end") - wrapper = loopy.assume(wrapper, "start >= 0") - if builder.extruded: - wrapper = loopy.assume(wrapper, "{0} <= {1}".format(parameters.layer_start, parameters.layer_end)) - # prioritize loops for indices in context.index_ordering: wrapper = loopy.prioritize_loops(wrapper, indices)