diff --git a/devito/ir/iet/scheduler.py b/devito/ir/iet/scheduler.py index 51b72a4c83..ae8db93775 100644 --- a/devito/ir/iet/scheduler.py +++ b/devito/ir/iet/scheduler.py @@ -7,7 +7,7 @@ ExpressionBundle, Transformer, FindNodes, FindSymbols, MapExprStmts, XSubs, iet_analyze) from devito.symbolics import IntDiv, ccode, xreplace_indices -from devito.tools import as_mapper, as_tuple +from devito.tools import as_mapper, as_tuple, flatten from devito.types import ConditionalDimension __all__ = ['iet_build', 'iet_insert_decls', 'iet_insert_casts'] @@ -168,7 +168,7 @@ def iet_insert_decls(iet, external): continue elif i._mem_stack: # On the stack - allocator.push_object_on_stack(iet[0], i) + allocator.push_array_on_stack(iet[0], i) else: # On the heap allocator.push_array_on_heap(i) @@ -199,16 +199,21 @@ def __init__(self): self.stack = OrderedDict() def push_object_on_stack(self, scope, obj): - """Define an Array or a composite type (e.g., a struct) on the stack.""" + """Define a LocalObject on the stack.""" handle = self.stack.setdefault(scope, OrderedDict()) + handle[obj] = Element(c.Value(obj._C_typename, obj.name)) - if obj.is_LocalObject: - handle[obj] = Element(c.Value(obj._C_typename, obj.name)) - else: - shape = "".join("[%s]" % ccode(i) for i in obj.symbolic_shape) - alignment = "__attribute__((aligned(%d)))" % obj._data_alignment - value = "%s%s %s" % (obj.name, shape, alignment) - handle[obj] = Element(c.POD(obj.dtype, value)) + def push_array_on_stack(self, scope, obj): + """Define an Array on the stack.""" + handle = self.stack.setdefault(scope, OrderedDict()) + + if obj in flatten(self.stack.values()): + return + + shape = "".join("[%s]" % ccode(i) for i in obj.symbolic_shape) + alignment = "__attribute__((aligned(%d)))" % obj._data_alignment + value = "%s%s %s" % (obj.name, shape, alignment) + handle[obj] = Element(c.POD(obj.dtype, value)) def push_scalar_on_stack(self, scope, expr): """Define a Scalar on the stack.""" diff --git a/tests/test_operator.py b/tests/test_operator.py index 789173d53b..682c4ddd5c 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -6,10 +6,12 @@ SparseFunction, SparseTimeFunction, Dimension, error, SpaceDimension, NODE, CELL, dimensions, configuration, TensorFunction, TensorTimeFunction, VectorFunction, VectorTimeFunction) -from devito.ir.iet import (Expression, Iteration, FindNodes, IsPerfectIteration, +from devito.ir.equations import ClusterizedEq +from devito.ir.iet import (Conditional, Expression, Iteration, FindNodes, + IsPerfectIteration, derive_parameters, iet_insert_decls, retrieve_iteration_tree) from devito.ir.support import Any, Backward, Forward -from devito.symbolics import indexify, retrieve_indexed +from devito.symbolics import ListInitializer, indexify, retrieve_indexed from devito.tools import flatten from devito.types import Array, Scalar @@ -1143,6 +1145,22 @@ def test_stack_vector_temporaries(self): timers->section0 += (double)(end_section0.tv_sec-start_section0.tv_sec)\ +(double)(end_section0.tv_usec-start_section0.tv_usec)/1000000;""" in str(operator) + def test_conditional_declarations(self): + a = Array(name='a', dimensions=(x,), dtype=np.int32, scope='stack') + list_initialize = Expression(ClusterizedEq(Eq(a, ListInitializer([0, 0])))) + iet = Conditional(x < 3, list_initialize, list_initialize) + parameters = derive_parameters(iet, True) + iet = iet_insert_decls(iet, parameters) + assert str(iet[0]) == """\ +if (x < 3) +{ + int a[x_size] = {0, 0}; +} +else +{ + int a[x_size] = {0, 0}; +}""" + class TestLoopScheduling(object): diff --git a/tests/test_ops.py b/tests/test_ops.py index 64e4d97a31..042656114c 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -13,13 +13,15 @@ # thus invalidating all of the future tests. This is guaranteed by the # `pytestmark` above from devito import Eq, Function, Grid, Operator, TimeFunction, configuration # noqa +from devito.ir.equations import ClusterizedEq # noqa +from devito.ir.iet import Conditional, Expression, derive_parameters, iet_insert_decls # noqa from devito.ops.node_factory import OPSNodeFactory # noqa from devito.ops.transformer import create_ops_arg, create_ops_dat, make_ops_ast, to_ops_stencil # noqa -from devito.ops.types import OpsAccessible, OpsDat, OpsStencil, OpsBlock # noqa +from devito.ops.types import Array, OpsAccessible, OpsDat, OpsStencil, OpsBlock # noqa from devito.ops.utils import namespace, AccessibleInfo, OpsDatDecl, OpsArgDecl # noqa -from devito.symbolics import Byref, Literal, indexify # noqa +from devito.symbolics import Byref, ListInitializer, Literal, indexify # noqa from devito.tools import dtype_to_cstr # noqa -from devito.types import Buffer, Constant, Symbol # noqa +from devito.types import Buffer, Constant, DefaultDimension, Symbol # noqa class TestOPSExpression(object): @@ -272,11 +274,24 @@ def test_create_ops_block(self, equation, expected): ]) def test_upper_bound(self, equation, expected): grid = Grid((5, 5)) - u = TimeFunction(name='u', grid=grid) # noqa + u = TimeFunction(name='u', grid=grid) # noqa op = Operator(eval(equation)) assert expected in str(op.ccode) + @pytest.mark.parametrize('equation, declaration', [ + ('Eq(u.forward, u+1)', + 'int OPS_Kernel_0_range[4]') + ]) + def test_single_declaration(self, equation, declaration): + grid = Grid((5, 5)) + u = TimeFunction(name='u', grid=grid) # noqa + op = Operator(eval(equation)) + + occurrences = [i for i in str(op.ccode).split('\n') if declaration in i] + + assert len(occurrences) == 1 + @pytest.mark.parametrize('equation,expected', [ ('Eq(u_2d.forward, u_2d + 1)', '[\'ops_dat_fetch_data(u_dat[(time_M)%(2)],0,&(u[(time_M)%(2)]));\','