diff --git a/src/psyclone/domain/lfric/__init__.py b/src/psyclone/domain/lfric/__init__.py index 158ed4e980..2eb640cc88 100644 --- a/src/psyclone/domain/lfric/__init__.py +++ b/src/psyclone/domain/lfric/__init__.py @@ -67,6 +67,8 @@ from psyclone.domain.lfric.lfric_loop import LFRicLoop from psyclone.domain.lfric.lfric_kern_call_factory import LFRicKernCallFactory from psyclone.domain.lfric.lfric_collection import LFRicCollection +from psyclone.domain.lfric.formal_kernel_args_from_metadata import \ + FormalKernelArgsFromMetadata from psyclone.domain.lfric.lfric_fields import LFRicFields from psyclone.domain.lfric.lfric_run_time_checks import LFRicRunTimeChecks from psyclone.domain.lfric.lfric_invokes import LFRicInvokes @@ -81,6 +83,7 @@ __all__ = [ 'ArgOrdering', + 'FormalKernelArgsFromMetadata', 'FunctionSpace', 'KernCallAccArgList', 'KernCallArgList', diff --git a/src/psyclone/domain/lfric/formal_kernel_args_from_metadata.py b/src/psyclone/domain/lfric/formal_kernel_args_from_metadata.py new file mode 100644 index 0000000000..851488efea --- /dev/null +++ b/src/psyclone/domain/lfric/formal_kernel_args_from_metadata.py @@ -0,0 +1,1328 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2023-2024, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Author: R. W. Ford, STFC Daresbury Lab + +'''This module implements a class that takes LFRic kernel metadata as +input and outputs the expected kernel subroutine arguments (based on +this metadata) within an LFRic PSyIR symbol table. + +''' +from psyclone.domain import lfric +from psyclone.errors import InternalError +from psyclone.psyir import nodes, symbols + +# TODO: many symbols here are not declared in LFRicTypes + + +class FormalKernelArgsFromMetadata(lfric.MetadataToArgumentsRules): + '''Provides the expected kernel subroutine arguments within an LFRic + PSyIR symbol table based on the provided LFRic Kernel metadata. + + ''' + _access_lookup = { + "gh_read": symbols.ArgumentInterface.Access.READ, + "gh_write": symbols.ArgumentInterface.Access.WRITE, + "gh_readwrite": symbols.ArgumentInterface.Access.READWRITE, + "gh_inc": symbols.ArgumentInterface.Access.READWRITE, + "gh_sum": symbols.ArgumentInterface.Access.READWRITE} + + # It is clearer to use 'symbol_table' here rather than 'info' + # pylint: disable=arguments-renamed + @classmethod + def _initialise(cls, symbol_table=None): + '''Initialise any additional state for this class. + + :param symbol_table: the symbol table that the kernel + arguments should be added to. If it is set to None then a new + symbol table is created. + :type symbol_table: + Optional[:py:class:`psyclone.psyir.symbols.SymbolTable`] + + :raises TypeError: if the symbol_table argument is an + unexpected type. + + ''' + # TODO: Should this be an LFRicSymbolTable? + if not symbol_table: + symbol_table = symbols.SymbolTable() + elif not isinstance(symbol_table, symbols.SymbolTable): + raise TypeError( + f"Expecting the optional 'symbol_table' argument to be a " + f"SymbolTable but found {type(symbol_table).__name__}.") + # We could use super()._initialise(symbol_table) here but it is simpler + # to just set the value of _info directly and it avoids pylint + # complaining. + cls._info = symbol_table + + @classmethod + def _add_precision_symbol(cls, sym): + ''' + ''' + if sym.name in cls._info: + return + const = lfric.LFRicConstants() + mod_name = const.UTILITIES_MOD_MAP["constants"]["module"] + mod_sym = cls._info.find_or_create(mod_name, + symbol_type=symbols.ContainerSymbol) + sym.interface = symbols.ImportInterface(mod_sym) + cls._info.add(sym) + + @classmethod + def _cell_position(cls): + ''''cell' argument providing the cell position. This is an integer of + type i_def and has intent in. + + ''' + cls._add_lfric_symbol_name("CellPositionDataSymbol", "cell") + + @classmethod + def _mesh_height(cls): + ''''nlayers' argument providing the mesh height. This is an integer + of type i_def and has intent in. + + ''' + # TODO: if self._kern.iterates_over not in ["cell_column", "domain"]: + # return + cls._add_lfric_symbol_name("MeshHeightDataSymbol", "nlayers") + + @classmethod + def _mesh_ncell2d_no_halos(cls): + ''''ncell_2d_no_halos' argument providing the number of columns in + the mesh ignoring halos. This is an integer of type i_def and + has intent in. + + ''' + cls._add_lfric_symbol_name( + "LFRicIntegerScalarDataSymbol", "ncell_2d_no_halos") + + @classmethod + def _mesh_ncell2d(cls): + ''''ncell_2d' argument providing the number of columns in the mesh + including halos. This is an integer of type i_def and has + intent in. + + ''' + # This symbol might be used to dimension an array in another + # method and therefore could have already been declared. + symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "ncell_2d") + cls._append_to_arg_list(symbol) + + @classmethod + def _cell_map(cls): + '''Four arguments providing a mapping from coarse to fine mesh for the + current column. The first is 'cell_map', an integer array of + rank two, kind i_def and intent in. This is followed by its + extents, 'ncell_f_per_c_x' and 'ncell_f_per_c_y' the numbers + of fine cells per coarse cell in the x and y directions, + respectively. These are integers of kind i_def and have intent + in. Lastly is 'ncell_f', the number of cells (columns) in the + fine mesh. This is an integer of kind i_def and has intent in. + + ''' + # Create the cell_map array extent symbols + x_symbol = cls._create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "ncell_f_per_c_x") + y_symbol = cls._create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "ncell_f_per_c_y") + # Use the array extent symbols to create the cell_map array symbol + scalar_type = cls._create_datatype("LFRicIntegerScalarDataType") + array_type = symbols.ArrayType( + scalar_type, [nodes.Reference(x_symbol), + nodes.Reference(y_symbol)]) + interface = symbols.ArgumentInterface( + symbols.ArgumentInterface.Access.READ) + cell_map_symbol = symbols.DataSymbol( + "cell_map", array_type, interface=interface) + cls._add_to_symbol_table(cell_map_symbol) + # Add the symbols to the symbol table argument list in the + # required order + cls._append_to_arg_list(cell_map_symbol) + cls._append_to_arg_list(x_symbol) + cls._append_to_arg_list(y_symbol) + # TODO Should be get or create + cls._add_lfric_symbol_name("LFRicIntegerScalarDataSymbol", "ncell_f") + + @classmethod + def _scalar(cls, meta_arg): + '''Argument providing an LFRic scalar value. + + :param meta_arg: the metadata associated with this scalar argument. + :type meta_arg: \ + :py:class:`psyclone.domain.lfric.kernel.ScalarArgMetadata` + + ''' + # TODO: This should be a meta_arg function or a mapping, or + # part of LFRicTypes? + datatype = meta_arg.datatype[3:] + # TODO: The name should come from meta_arg classes, or part of + # LFRicTypes? + datatype_char = meta_arg.datatype[3:4] + meta_arg_index = cls._metadata.meta_args.index(meta_arg) + cls._add_lfric_symbol_name( + f"LFRic{datatype.capitalize()}ScalarDataSymbol", + f"{datatype_char}scalar_{meta_arg_index+1}", + access=cls._access_lookup[meta_arg.access]) + + @classmethod + def _field(cls, meta_arg): + '''Argument providing an LFRic field. + + :param meta_arg: the metadata associated with this field argument. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + ''' + # This symbol is used in other methods so might have already + # been declared. + undf_name = cls._undf_name(meta_arg.function_space) + undf_symbol = cls._get_or_create_lfric_symbol( + "NumberOfUniqueDofsDataSymbol", undf_name) + + name = cls._field_name(meta_arg) + datatype = meta_arg.datatype[3:] + cls._add_lfric_symbol_name( + f"{datatype.capitalize()}FieldDataSymbol", name, + dims=[undf_symbol], + access=cls._access_lookup[meta_arg.access]) + + @classmethod + def _field_vector(cls, meta_arg): + '''Arguments providing an LFRic field vector. + + :param meta_arg: the metadata associated with this field + vector argument. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldVectorArgMetadata` + + ''' + # This symbol is used in other methods so might have already + # been declared. + undf_name = cls._undf_name(meta_arg.function_space) + undf_symbol = cls._get_or_create_lfric_symbol( + "NumberOfUniqueDofsDataSymbol", undf_name) + + field_name = cls._field_name(meta_arg) + datatype = meta_arg.datatype[3:] + access = cls._access_lookup[meta_arg.access] + for idx in range(int(meta_arg.vector_length)): + name = f"{field_name}_v{idx+1}" + cls._add_lfric_symbol_name( + f"{datatype.capitalize()}VectorFieldDataSymbol", name, + dims=[undf_symbol], access=access) + + @classmethod + def _operator(cls, meta_arg): + '''Arguments providing an LMA operator. First include an integer + extent of kind i_def with intent in. The default name of this + extent is '_ncell_3d'. Next include the + operator. This is a rank-3, real array. Its precision (kind) + depends on how it is defined in the algorithm layer and its + intent depends on its metadata. Its default name is + 'op_'. The extents of the first two + dimensions are the local degrees of freedom for the to and + from function spaces, respectively, and that of the third is + '_ncell_3d'. + + :param meta_arg: the metadata associated with the operator + arguments. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.OperatorArgMetadata` + + ''' + meta_arg_index = cls._metadata.meta_args.index(meta_arg) + name = f"op_{meta_arg_index+1}_ncell_3d" + op_ncell_3d_symbol = cls._create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", name) + cls._append_to_arg_list(op_ncell_3d_symbol) + + # This symbol is used in other methods so might have already + # been declared. + ndf_name_to = cls._ndf_name(meta_arg.function_space_to) + ndf_to_symbol = cls._get_or_create_lfric_symbol( + "NumberOfDofsDataSymbol", ndf_name_to) + + # This symbol is used in other methods so might have already + # been declared. + ndf_name_from = cls._ndf_name(meta_arg.function_space_from) + ndf_from_symbol = cls._get_or_create_lfric_symbol( + "NumberOfDofsDataSymbol", ndf_name_from) + + operator_name = cls._operator_name(meta_arg) + access = cls._access_lookup[meta_arg.access] + cls._add_lfric_symbol_name( + "OperatorDataSymbol", operator_name, + dims=[ndf_to_symbol, ndf_from_symbol, op_ncell_3d_symbol], + access=access) + + @classmethod + def _cma_operator(cls, meta_arg): + '''Arguments providing a columnwise operator. First include a real, + 3-dimensional array of kind r_solver with its intent depending + on its metadata. Its default name is + 'cma_op_', hereon specified as + '' and the default names of its dimensions are + 'bandwidth_', 'nrow_', and + 'ncell_2d'. Next the number of rows in the banded matrix is + provided. This is an integer of kind i_def with intent in with + default name 'nrow_'. If the from-space of the + operator is not the same as the to-space then the number of + columns in the banded matrix is provided next. This is an + integer of kind i_def with intent in and has default name + 'ncol_'. Next the bandwidth of the banded + matrix is added. This is an integer of kind i_def with intent + in and has default name 'bandwidth_'. Next + banded-matrix parameter alpha is added. This is an integer of + kind i_def with intent in and has default name + 'alpha_'. Next banded-matrix parameter beta is + added. This is an integer of kind i_def with intent in and has + default name 'beta_. Next banded-matrix + parameter gamma_m is added. This is an integer of kind i_def + with intent in and has default name 'gamma_m_'. + Finally banded-matrix parameter gamma_p is added. This is an + integer of kind i_def with intent in and has default name + 'gamma_p_'. + + :param meta_arg: the metadata associated with the CMA operator + arguments. + :type meta_arg: :py:class:`psyclone.domain.lfric.kernel. + ColumnwiseOperatorArgMetadata` + + ''' + operator_name = cls._cma_operator_name(meta_arg) + + bandwidth = cls._create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"bandwidth_{operator_name}") + nrow = cls._create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"nrow_{operator_name}") + # This symbol is used in other methods so might have already + # been declared. + ncell_2d = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "ncell_2d") + + access = cls._access_lookup[meta_arg.access] + # TODO: should be r_solver precision + cls._add_lfric_symbol_name( + "OperatorDataSymbol", operator_name, + dims=[bandwidth, nrow, ncell_2d], access=access) + cls._append_to_arg_list(nrow) + if meta_arg.function_space_from != meta_arg.function_space_to: + cls._add_lfric_symbol_name( + "LFRicIntegerScalarDataSymbol", f"ncol_{operator_name}") + cls._append_to_arg_list(bandwidth) + cls._add_lfric_symbol_name( + "LFRicIntegerScalarDataSymbol", f"alpha_{operator_name}") + cls._add_lfric_symbol_name( + "LFRicIntegerScalarDataSymbol", f"beta_{operator_name}") + cls._add_lfric_symbol_name( + "LFRicIntegerScalarDataSymbol", f"gamma_m_{operator_name}") + cls._add_lfric_symbol_name( + "LFRicIntegerScalarDataSymbol", f"gamma_p_{operator_name}") + + @classmethod + def _ref_element_properties(cls, meta_ref_element): + '''Arguments required if there are reference element properties + specified in the metadata. + + If either the normals_to_horizontal_faces or + outward_normals_to_horizontal_faces properties of the + reference element are required then pass the number of + horizontal faces of the reference element, with default name + nfaces_re_h. Similarly, if either the + normals_to_vertical_faces or outward_normals_to_vertical_faces + are required then pass the number of vertical faces, with + default name 'nfaces_re_v'. This also holds for the + normals_to_faces and outward_normals_to_faces where the number + of all faces of the reference element (with default name + nfaces_re) is passed to the kernel. All of these quantities + are integers of kind i_def with intent in. + + Then, in the order specified in the meta_reference_element + metadata: For the + [outward_]normals_to_horizontal/[outward_]vertical_faces, pass + a rank-2 integer array of kind i_def with dimensions (3, + nfaces_re_h/v) and intent in. For normals_to_faces or + outward_normals_to_faces pass a rank-2 integer array of kind + i_def with dimensions (3, nfaces_re) and intent in. In each + case the default name is the same name as the reference + element property. + + :param meta_ref_element: the metadata capturing the reference + element properties required by the kernel. + :type meta_ref_element: List[:py:class:`psyclone.domain.lfric. + kernel.MetaRefElementArgMetadata`] + + ''' + if [entry for entry in meta_ref_element if entry.reference_element in + ["normals_to_horizontal_faces", + "outward_normals_to_horizontal_faces"]]: + nfaces_re_h = cls._create_lfric_symbol( + "NumberOfQrPointsInXyDataSymbol", "nfaces_re_h") + cls._append_to_arg_list(nfaces_re_h) + + if [entry for entry in meta_ref_element if entry.reference_element in + ["normals_to_vertical_faces", + "outward_normals_to_vertical_faces"]]: + nfaces_re_v = cls._create_lfric_symbol( + "NumberOfQrPointsInZDataSymbol", "nfaces_re_v") + cls._append_to_arg_list(nfaces_re_v) + + if [entry for entry in meta_ref_element if entry.reference_element in + ["normals_to_faces", + "outward_normals_to_faces"]]: + nfaces_re = cls._create_lfric_symbol( + "NumberOfQrPointsInFacesDataSymbol", "nfaces_re") + cls._append_to_arg_list(nfaces_re) + + scalar_type = cls._create_datatype("LFRicIntegerScalarDataType") + interface = symbols.ArgumentInterface( + symbols.ArgumentInterface.Access.READ) + + for ref_element_property in meta_ref_element: + + if ref_element_property.reference_element in [ + "normals_to_horizontal_faces", + "outward_normals_to_horizontal_faces"]: + ref_element_dim = nfaces_re_h + elif ref_element_property.reference_element in [ + "normals_to_vertical_faces", + "outward_normals_to_vertical_faces"]: + ref_element_dim = nfaces_re_v + elif ref_element_property.reference_element in [ + "normals_to_faces" or "outward_normals_to_faces"]: + ref_element_dim = nfaces_re + else: + raise InternalError( + f"Unsupported reference element property " + f"'{ref_element_property.reference_element}' found.") + + array_type = symbols.ArrayType( + scalar_type, [nodes.Literal("3", scalar_type), + nodes.Reference(ref_element_dim)]) + property_symbol = symbols.DataSymbol( + ref_element_property.reference_element, array_type, + interface=interface) + cls._add_to_symbol_table(property_symbol) + cls._append_to_arg_list(property_symbol) + + @classmethod + def _mesh_properties(cls, meta_mesh): + '''All arguments required for mesh properties specified in the kernel + metadata. + + If the adjacent_face mesh property is required then, if the + number of horizontal cell faces obtained from the reference + element (nfaces_re_h) is not already being passed to the + kernel via the reference element then supply it here. This is + an integer of kind i_def with intent in and has default name + 'nfaces_re_h'. + + Also pass a rank-1, integer array with intent in of kind i_def + and extent nfaces_re_h, with default name being the name of + the property (adjacent_face in this case). + + :param meta_mesh: the metadata capturing the mesh properties + required by the kernel. + :type meta_mesh: List[ + :py:class:`psyclone.domain.lfric.kernel.MetaMeshArgMetadata`] + + raises InternalError: if the mesh property is not 'adjacent_face'. + + ''' + for mesh_property in meta_mesh: + if mesh_property.mesh == "adjacent_face": + # nfaces_re_h may have been passed via the reference + # element logic. + nfaces_re_h = cls._get_or_create_lfric_symbol( + "NumberOfQrPointsInXyDataSymbol", "nfaces_re_h") + if not cls._metadata.meta_ref_element or not \ + [entry for entry in cls._metadata.meta_ref_element + if entry.reference_element in [ + "normals_to_horizontal_faces", + "outward_normals_to_horizontal_faces"]]: + # nfaces_re_h was not been passed via the + # reference element logic so add it the argument + # list. + cls._append_to_arg_list(nfaces_re_h) + interface = symbols.ArgumentInterface( + symbols.ArgumentInterface.Access.READ) + scalar_type = cls._create_datatype( + "LFRicIntegerScalarDataType") + array_type = symbols.ArrayType( + scalar_type, [nodes.Reference(nfaces_re_h)]) + mesh_symbol = symbols.DataSymbol( + mesh_property.mesh, array_type, + interface=interface) + cls._add_to_symbol_table(mesh_symbol) + cls._append_to_arg_list(mesh_symbol) + else: + raise InternalError( + f"Unexpected mesh property '{mesh_property.mesh}' found. " + f"Expected 'adjacent_face'.") + + @classmethod + def _fs_common(cls, function_space): + '''Arguments associated with a function space that are common to + fields and operators. Add the number of degrees of freedom for + this function space. This is an integer of kind i_def with + intent in and default name 'ndf_'. + + :param str function_space: the current function space. + + ''' + # TODO: if self._kern.iterates_over not in ["cell_column", "domain"]: + # return + + function_space_name = cls._function_space_name(function_space) + ndf_name = cls._ndf_name(function_space) + # This symbol might be used to dimension an array in another + # method and therefore could have already been declared. + symbol = cls._get_or_create_lfric_symbol("NumberOfDofsDataSymbol", + ndf_name) + cls._append_to_arg_list(symbol) + + @classmethod + def _fs_compulsory_field(cls, function_space): + '''Compulsory arguments for this function space. First include the + unique number of degrees of freedom for this function + space. This is a scalar integer of kind i_def with intent in + and it's default name is 'undf'_. Second + include the dof map for this function space. This is a 1D + integer array with dimension 'ndf_' of kind + i_def with intent in and its' default name is + 'map_. + + :param str function_space: the current function space. + + ''' + undf_name = cls._undf_name(function_space) + # This symbol might be used to dimension an array in another + # method and therefore could have already been declared. + undf_symbol = cls._get_or_create_lfric_symbol( + "NumberOfUniqueDofsDataSymbol", undf_name) + + # TODO: if domain pass whole dofmap + + # get ndf + ndf_name = cls._ndf_name(function_space) + ndf_symbol = cls._get_or_create_lfric_symbol( + "NumberOfDofsDataSymbol", ndf_name) + # create dofmap(ndf) + dofmap_name = cls._dofmap_name(function_space) + dofmap_symbol = cls._create_array_symbol( + "LFRicIntegerScalarDataType", dofmap_name, + dims=[ndf_symbol]) + + # Add the symbols to the symbol table argument list in the + # required order + cls._append_to_arg_list(undf_symbol) + cls._append_to_arg_list(dofmap_symbol) + + @classmethod + def _fs_intergrid(cls, meta_arg): + '''Function-space related arguments for an intergrid kernel. + + For this field include the required dofmap information. + + If the dofmap is associated with an argument on the fine mesh, + include the number of DoFs per cell for the FS of the field on + the fine mesh. This is an integer with intent in and precision + i_def. Its default name is 'ndf_'. Next, + include the number of unique DoFs per cell for the FS of the + field on the fine mesh. This is an integer with intent in and + precision i_def. Its default name is + 'undf_'. Lastly include the whole dofmap for + the fine mesh. This is an integer array of rank two and kind + i_def with intent in. The extent of the first dimension is + ndf_ and that of the second is ncell_f. Its + default name is 'full_map_'. + + If the dofmap is associated with an argument on the coarse + mesh then include undf_coarse, the number of unique DoFs for + the coarse field. This is an integer of kind i_def with intent + in. Its default name is 'undf_'. Lastly, + include the dofmap for the current cell (column) in the coarse + mesh. This is an integer array of rank one, kind i_def and + has intent in. Its default name is 'map_'. + + :param meta_arg: the metadata capturing the InterGrid argument + required by the kernel. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.InterGridArgMetadata`] + + ''' + function_space = meta_arg.function_space + function_space_name = cls._function_space_name(function_space) + if meta_arg.mesh_arg == "gh_fine": + # add ndf symbol + cls._fs_common(function_space) + # add undf symbol (may have already been declared) + undf_name = cls._undf_name(function_space) + symbol = cls._get_or_create_lfric_symbol( + "NumberOfUniqueDofsDataSymbol", undf_name) + cls._append_to_arg_list(symbol) + # get ndf_symbol (has just been added) + ndf_name = cls._ndf_name(function_space) + ndf_symbol = cls._info.lookup(ndf_name) + # get ncell_f symbol (may have already been declared) + ncell_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "ncell_f") + # add full_dofmap(ndf, ncell_f) + fullmap_name = cls._fullmap_name(function_space) + cls._add_array_symbol_name( + "LFRicIntegerScalarDataType", fullmap_name, + dims=[ndf_symbol, ncell_symbol]) + else: # "gh_coarse" + # undf + dofmap + cls._fs_compulsory_field(function_space) + + @classmethod + def _basis_or_diff_basis_dimension(cls, name, function_space): + ''' xxx ''' + if name == "basis": + return cls._basis_dimension(function_space) + elif name == "diff_basis": + return cls._diff_basis_dimension(function_space) + else: + raise Exception("xxx") + + @staticmethod + def _basis_dimension(function_space): + ''' xxx ''' + fs_1_list = ["w0", "w2trace", "w2htrace", "w2vtrace", "w3", + "wtheta", "wchi"] + fs_3_list = ["w1", "w2", "w2h", "w2v", "w2broken", "any_w2"] + if function_space.lower() in fs_1_list: + return 1 + elif function_space.lower() in fs_3_list: + return 3 + else: + raise ValueError( + f"Unexpected function space value '{function_space}' found " + f"in basis_dimension. Expected one of {fs_1_list+fs_3_list}.") + + @staticmethod + def _diff_basis_dimension(function_space): + ''' xxx ''' + fs_1_list = ["w2", "w2h", "w2v", "w2broken", "any_w2"] + fs_3_list = ["w0", "w1", "w2trace", "w2htrace", "w2vtrace", "w3", + "wtheta", "wchi"] + if function_space.lower() in fs_1_list: + return 1 + elif function_space.lower() in fs_3_list: + return 3 + else: + raise ValueError( + f"Unexpected function space value '{function_space}' found " + f"in diff_basis_dimension. Expected one of " + f" {fs_1_list+fs_3_list}.") + + @classmethod + def _basis_or_diff_basis(cls, name, function_space): + '''Utility function for the basis and diff_basis methods. + + For each operation on the function space (name = ["basis", + "diff_basis"]), in the order specified in the metadata, pass + real arrays of kind r_def with intent in. For each shape + specified in the gh_shape metadata entry: + + If shape is gh_quadrature_* then the arrays are of rank four + and have default name + "_"_. + + If shape is gh_quadrature_xyoz then the arrays have extent + (dimension, number_of_dofs, np_xy, np_z). + + If shape is gh_quadrature_face or gh_quadrature_edge then the + arrays have extent (dimension, number_of_dofs, np_xyz, nfaces + or nedges). + + If shape is gh_evaluator then pass one array for each target + function space (i.e. as specified by + gh_evaluator_targets). Each of these arrays are of rank three + with extent (dimension, number_of_dofs, + ndf_). The default name of the argument + is "_""_on_". + + Here is the name of the corresponding + quadrature object being passed to the Invoke. dimension is 1 + or 3 and depends upon the function space and whether or not + it is a basis or a differential basis function (see the table + below). number_of_dofs is the number of degrees of freedom + (DoFs) associated with the function space and np_* are the + number of points to be evaluated: i) *_xyz in all directions + (3D); ii) *_xy in the horizontal plane (2D); iii) *_x, *_y in + the horizontal (1D); and iv) *_z in the vertical (1D). nfaces + and nedges are the number of horizontal faces/edges obtained + from the appropriate quadrature object supplied to the Invoke. + + Function Type Dimension Function Space Name + + Basis 1 W0, W2trace, W2Htrace, W2Vtrace, W3, + Wtheta, Wchi + 3 W1, W2, W2H, W2V, W2broken, ANY_W2 + + Differential Basis 1 W2, W2H, W2V, W2broken, ANY_W2 + + 3 W0, W1, W2trace, W2Htrace, W2Vtrace, + W3, Wtheta, Wchi + + :param str name: 'basis' or 'diff_basis'. + :param str function_space: the current function space. + + raises InternalError: if unexpected shape metadata is found. + + ''' + function_space_name = cls._function_space_name(function_space) + const = lfric.LFRicConstants() + if not cls._metadata.shapes: + return + dimension = cls._basis_or_diff_basis_dimension(name, function_space) + scalar_datatype = cls._create_datatype("LFRicIntegerScalarDataType") + dimension_literal = nodes.Literal(str(dimension), scalar_datatype) + # This symbol is used in other methods so might have already + # been declared. + ndf_name = cls._ndf_name(function_space) + ndf_symbol = cls._get_or_create_lfric_symbol( + "NumberOfDofsDataSymbol", ndf_name) + + for shape in cls._metadata.shapes: + if shape in const.VALID_QUADRATURE_SHAPES: + quad_name = shape.split('_')[-1] + if shape == "gh_quadrature_xyoz": + np_xy_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", + f"np_xy_qr_{quad_name}") + np_z_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"np_z_qr_{quad_name}") + cls._append_to_arg_list(np_xy_symbol) + cls._append_to_arg_list(np_z_symbol) + dims = [dimension_literal, ndf_symbol, np_xy_symbol, + np_z_symbol] + elif shape == "gh_quadrature_face": + np_xyz_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"np_xyz_qr_{quad_name}") + nfaces_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"nfaces_qr_{quad_name}") + cls._append_to_arg_list(np_xyz_symbol) + cls._append_to_arg_list(nfaces_symbol) + dims = [dimension_literal, ndf_symbol, np_xyz_symbol, + nfaces_symbol] + elif shape == "gh_quadrature_edge": + np_xyz_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"np_xyz_qr_{quad_name}") + nedges_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"nedges_qr_{quad_name}") + cls._append_to_arg_list(np_xyz_symbol) + cls._append_to_arg_list(nedges_symbol) + dims = [dimension_literal, ndf_symbol, np_xyz_symbol, + nedges_symbol] + else: + raise Exception("xxx") + cls._add_array_symbol_name( + "LFRicRealScalarDataType", + f"{name}_{function_space_name}_qr_{quad_name}", + dims) + elif shape in const.VALID_EVALUATOR_SHAPES: + if cls._metadata.evaluator_targets: + target_function_spaces = cls._metadata.evaluator_targets + else: + # Targets are the function spaces of all modified fields. + fields = [field for field in cls._metadata.meta_args + if type(field) in [ + lfric.kernel.FieldArgMetadata, + lfric.kernel.FieldVectorArgMetadata, + lfric.kernel.InterGridArgMetadata, + lfric.kernel.InterGridVectorArgMetadata] + and field.access != "gh_read"] + target_function_spaces = [] + for field in fields: + if field.function_space not in target_function_spaces: + target_function_spaces.append(field.function_space) + for target_function_space in target_function_spaces: + + target_ndf_name = cls._ndf_name(target_function_space) + target_ndf_symbol = cls._get_or_create_lfric_symbol( + "NumberOfDofsDataSymbol", ndf_name) + + dims = [dimension_literal, ndf_symbol, target_ndf_symbol] + target_function_space_name = cls._function_space_name( + target_function_space) + cls._add_array_symbol_name( + "LFRicIntegerScalarDataType", + f"{name}_{function_space_name}_to_" + f"{target_function_space_name}", + dims) + else: + raise InternalError( + f"Unexpected shape metadata. Found '{shape}' but expected " + f"one of {const.VALID_EVALUATOR_SHAPES}.") + + @classmethod + def _basis(cls, function_space): + '''Arguments associated with basis functions on the supplied function + space. + + :param str function_space: the current function space. + + ''' + cls._basis_or_diff_basis("basis", function_space) + + @classmethod + def _diff_basis(cls, function_space): + '''Arguments associated with differential basis functions on the + supplied function space. + + :param str function_space: the current function space. + + ''' + cls._basis_or_diff_basis("diff_basis", function_space) + + @classmethod + def _quad_rule(cls, shapes): + '''Quadrature information is required (gh_shape = + gh_quadrature_*). For each shape in the order specified in the + gh_shape metadata: + + Include integer, scalar arguments of kind i_def with intent in + that specify the extent of the basis/diff-basis arrays: + + If gh_shape is gh_quadrature_XYoZ then pass + np_xy_ and np_z_. + + If gh_shape is gh_quadrature_face/_edge then pass + nfaces/nedges_ and + np_xyz_. + + Include weights which are real arrays of kind r_def: + + If gh_quadrature_XYoZ pass in weights_xz_ + (rank one, extent np_xy_) and + weights_z_ (rank one, extent + np_z_). + + If gh_quadrature_face/_edge pass in + weights_xyz_ (rank two with extents + [np_xyz_, + nfaces/nedges_]). + + :param shapes: the metadata capturing the quadrature shapes + required by the kernel. + :type shapes: List[str] + + raises InternalError: if unexpected (quadrature) shape + metadata is found. + + ''' + return # ARPDBG + const = lfric.LFRicConstants() + for quad in shapes: + quad_name = quad.split('_')[-1] + if quad == "gh_quadrature_xyoz": + #cls.arg_info.extend([ + # f"npxy_{quad_name}", f"np_z_{quad_name}", + # f"weights_xz_{quad_name}", f"weights_z_{quad_name}"]) + #cls._arg_index += 4 + np_xy_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "np_xy") + np_z_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "np_z") + dims = [dimension_literal, ndf_symbol, np_xy_symbol, + np_z_symbol] + elif quad == "gh_quadrature_face": + #cls.arg_info.extend([ + # f"nfaces_{quad_name}", f"np_xyz_{quad_name}", + # f"weights_xyz_{quad_name}"]) + #cls._arg_index += 3 + np_xyz_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "np_xyz") + nfaces_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "nfaces") + dims = [dimension_literal, ndf_symbol, np_xyz_symbol, + nfaces_symbol] + elif quad == "gh_quadrature_edge": + #cls.arg_info.extend([ + # f"nedges_{quad_name}", f"np_xyz_{quad_name}", + # f"weights_xyz_{quad_name}"]) + #cls._arg_index += 3 + np_xyz_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "np_xyz") + nedges_symbol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", "nedges") + dims = [dimension_literal, ndf_symbol, np_xyz_symbol, + nedges_symbol] + else: + raise InternalError( + f"Unexpected shape metadata. Found '{quad}' but expected " + f"one of {const.VALID_QUADRATURE_SHAPES}.") + cls._add_array_symbol_name( + "LFRicIntegerScalarDataType", + f"{name}_{function_space_name}_qr_{shape.split('_')[-1]}", + dims) + + # pylint: disable=unidiomatic-typecheck + @classmethod + def _field_bcs_kernel(cls): + '''Fix for the field boundary condition kernel. Adds a boundary dofs + 2D integer array with intent in and kind i_def. The size of + the dimensions are ndf_, 2. Its default name + is "boundary_dofs_", where is the + default name of the field on the current function space. + + :raises InternalError: if the enforce_bc_kernel does not have + a single field argument on the any_space_1 function space. + + ''' + # Check that this kernel has a single field argument that is + # on the any_space_1 function space. + if len(cls._metadata.meta_args) != 1: + raise InternalError( + f"An enforce_bc_code kernel should have a single " + f"argument but found '{len(cls._metadata.meta_args)}'.") + meta_arg = cls._metadata.meta_args[0] + if not type(meta_arg) == lfric.kernel.FieldArgMetadata: + raise InternalError( + f"An enforce_bc_code kernel should have a single field " + f"argument but found '{type(meta_arg).__name__}'.") + if not meta_arg.function_space == "any_space_1": + raise InternalError( + f"An enforce_bc_code kernel should have a single field " + f"argument on the 'any_space_1' function space, but found " + f"'{meta_arg.function_space}'.") + + field_name = cls._field_name(meta_arg) + # 2d integer array + print("TO BE ADDED") + exit(1) + # cls._add_lfric_symbol_name("xxx", f"boundary_dofs_{field_name}") + + @classmethod + def _operator_bcs_kernel(cls): + '''Fix for the operator boundary condition kernel. Adds a boundary + dofs 2D integer array with intent in and kind i_def. The size + of the dimensions are ndf_, 2. Its default + name is "boundary_dofs_", where is + the default name of the field on the current function space. + + :raises InternalError: if the enforce_operator_bc_kernel does + not have a single lma operator argument. + + ''' + # Check that this kernel has a single LMA argument. + if len(cls._metadata.meta_args) != 1: + raise InternalError( + f"An enforce_operator_bc_code kernel should have a single " + f"argument but found '{len(cls._metadata.meta_args)}'.") + meta_arg = cls._metadata.meta_args[0] + if type(meta_arg) is not lfric.OperatorArgMetadata: + raise InternalError( + f"An enforce_operator_bc_code kernel should have a single " + f"lma operator argument but found " + f"'{type(meta_arg).__name__}'.") + + lma_operator_name = cls._operator_name(meta_arg) + print("TO BE ADDED") + exit(1) + # cls._add_lfric_symbol_name("", f"boundary_dofs_{lma_operator_name}") + + @classmethod + def _stencil_2d_unknown_extent(cls, meta_arg): + '''The field entry has a stencil access of type cross2d so add a 1D + integer array of extent 4 and kind i_def stencil-size argument + with intent in. The default name is + "_stencil_size", where is the default + name of the field with this stencil. This will supply the + number of cells in each branch of the stencil. + + :param meta_arg: the metadata associated with a field argument + with a cross2d stencil access. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + ''' + field_name = cls._field_name(meta_arg) + # Array + print("TO BE DONE") + exit(1) + # cls._add_lfric_symbol_name(f"{field_name}_stencil_size") + + @classmethod + def _stencil_2d_max_extent(cls, meta_arg): + '''The field entry has a stencil access of type cross2d so add an + integer of kind i_def and intent in for the max branch + length. The default name is "_max_branch_length", + where is the default name of the field with this + stencil. This is used in defining the dimensions of the + stencil dofmap array and is required due to the varying length + of the branches of the stencil when used on planar meshes. + + :param meta_arg: the metadata associated with a field argument + with a cross2d stencil access. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + ''' + field_name = cls._field_name(meta_arg) + cls._add_lfric_symbol_name( + "LFRicIntegerScalar", f"{field_name}_max_branch_length") + + @classmethod + def _stencil_unknown_extent(cls, meta_arg): + '''The field entry has a stencil access so add an integer stencil-size + argument with intent in and kind i_def. The default name is + "_stencil_size", where is the default + name of the field with this stencil. This argument will + contain the number of cells in the stencil. + + :param meta_arg: the metadata associated with a field argument + with a stencil access. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + ''' + field_name = cls._field_name(meta_arg) + cls._add_lfric_symbol_name( + "LFRicIntegerScalar", f"{field_name}_stencil_size") + + @classmethod + def _stencil_unknown_direction(cls, meta_arg): + '''The field entry stencil access is of type XORY1D so add an + additional integer direction argument of kind i_def and with + intent in, with default name "_direction", where + is the default name of the field with this + stencil. + + :param meta_arg: the metadata associated with a field argument + with a xory1d stencil access. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + ''' + field_name = cls._field_name(meta_arg) + cls._add_lfric_symbol_name("LFRicIntegerScalar", + f"{field_name}_direction") + + @classmethod + def _stencil_2d(cls, meta_arg): + '''Stencil information that is passed from the Algorithm layer if the + stencil is 'cross2d'. Add a 3D stencil dofmap array of type + integer, kind i_def and intent in. The dimensions are + (number-of-dofs-in-cell, max-branch-length, 4). The default + name is "_stencil_dofmap", where is + the default name of the field with this stencil. + + :param meta_arg: the metadata associated with a field argument + with a stencil access. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + ''' + field_name = cls._field_name(meta_arg) + cls._add_lfric_symbol_name( + "LFRicIntegerScalar", f"{field_name}_stencil_dofmap") + + @classmethod + def _stencil(cls, meta_arg): + '''Stencil information that is passed from the Algorithm layer if the + stencil is not 'cross2d'. Add a 2D stencil dofmap array of + type integer, kind i_def and intent in. The dimensions are + (number-of-dofs-in-cell, stencil-size). The default name is + "_stencil_dofmap, where is the + default name of the field with this stencil. + + :param meta_arg: the metadata associated with a field argument + with a stencil access. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + ''' + field_name = cls._field_name(meta_arg) + print("TO BE DONE") + exit(1) + # cls._add_lfric_symbol_name(f"{field_name}_stencil_dofmap") + + @classmethod + def _banded_dofmap(cls, function_space, cma_operator): + '''Adds a banded dofmap for the provided function space and cma + operator when there is an assembly cma kernel. + + Include the column-banded dofmap, the list of offsets for the + to/from-space. This is an integer array of rank 2 and kind + i_def with intent in. The first dimension is + "ndf_" and the second is nlayers. Its + default name is 'cbanded_map_''_' + + :param str function_space: the function space for this banded + dofmap. + :param cma_operator: the cma operator metadata associated with + this banded dofmap. + :type cma_operator: :py:class:`psyclone.domain.lfric.kernel. + ColumnwiseOperatorArgMetadata` + + ''' + function_space_name = cls._function_space_name(function_space) + name = cls._cma_operator_name(cma_operator) + ndf_name = cls._ndf_name(function_space) + ndf_symbol = cls._get_or_create_lfric_symbol( + "NumberOfDofsDataSymbol", ndf_name) + nlayers = cls._get_or_create_lfric_symbol("MeshHeightDataSymbol", + "nlayers") + cls._add_lfric_symbol_name( + "ColumnBandedDofMapDataSymbol", + f"cbanded_map_{function_space_name}_{name}", + dims=[ndf_symbol, nlayers]) + + @classmethod + def _indirection_dofmap(cls, function_space, cma_operator): + '''Adds an indirection dofmap for the provided function space and cma + operator when there is an apply cma kernel. + + Include the indirection map for the 'to' function space of the + supplied CMA operator. This is a rank-1 integer array of kind + i_def and intent in with extent nrow. Its default name is + 'cma_indirection_map_''_'. + + If the from-space of the operator is not the same as the + to-space then include the indirection map for the 'from' + function space of the CMA operator. This is a rank-1 integer + array of kind i_def and intent in with extent ncol. Its + default name is + 'cma_indirection_map_''_'. + + Note, this method will not be called for the from space if the + to and from function spaces are the same so there is no need + to explicitly check. + + :param str function_space: the function space for this + indirection dofmap. + :param cma_operator: the cma operator metadata associated with + this indirection dofmap. + :type cma_operator: :py:class:`psyclone.domain.lfric.kernel. + ColumnwiseOperatorArgMetadata` + + ''' + function_space_name = cls._function_space_name(function_space) + name = cls._cma_operator_name(cma_operator) + + if function_space == cma_operator.function_space_to: + nrow = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"nrow_{name}") + dims = [nrow] + else: + ncol = cls._get_or_create_lfric_symbol( + "LFRicIntegerScalarDataSymbol", f"ncol_{name}") + dims = [ncol] + cls._add_lfric_symbol_name( + "CMADofMapDataSymbol", + f"cma_dofmap_{function_space_name}", dims=dims) + + @classmethod + def _create_datatype(cls, class_name): + ''' xxx ''' + required_class = lfric.LFRicTypes(class_name) + symbol = required_class() + return symbol + + @classmethod + def _create_array_symbol(cls, scalar_type_name, symbol_name, dims, + access=symbols.ArgumentInterface.Access.READ): + ''' xxx ''' + scalar_type = cls._create_datatype(scalar_type_name) + array_args = [nodes.Reference(symbol) if not isinstance(symbol, + nodes.Literal) + else symbol for symbol in dims] + array_type = symbols.ArrayType(scalar_type, array_args) + interface = symbols.ArgumentInterface(access) + array_symbol = symbols.DataSymbol( + symbol_name, array_type, interface=interface) + cls._add_to_symbol_table(array_symbol) + + return array_symbol + + @classmethod + def _add_array_symbol_name(cls, scalar_type_name, symbol_name, dims, + access=symbols.ArgumentInterface.Access.READ): + '''xxx''' + symbol = cls._create_array_symbol( + scalar_type_name, symbol_name, dims, access=access) + cls._append_to_arg_list(symbol) + + @classmethod + def _create_lfric_symbol(cls, class_name, symbol_name, dims=None, + access=symbols.ArgumentInterface.Access.READ): + ''' xxx ''' + required_class = lfric.LFRicTypes(class_name) + if dims: + array_args = [nodes.Reference(symbol) for symbol in dims] + symbol = required_class(symbol_name, array_args) + else: + symbol = required_class(symbol_name) + symbol.interface = symbols.ArgumentInterface(access) + cls._add_to_symbol_table(symbol) + return symbol + + @classmethod + def _add_lfric_symbol_name(cls, class_name, symbol_name, dims=None, + access=symbols.ArgumentInterface.Access.READ): + '''Utility function to create an LFRic-PSyIR symbol of type class_name + and name symbol_name and add it to the symbol table. + + :param str class_name: the name of the class to be created. + :param str symbol_name: the name of the symbol to be created. + + ''' + symbol = cls._create_lfric_symbol( + class_name, symbol_name, dims=dims, access=access) + cls._append_to_arg_list(symbol) + + @classmethod + def _get_or_create_lfric_symbol( + cls, class_name, symbol_name, + access=symbols.ArgumentInterface.Access.READ): + ''' xxx ''' + try: + symbol = cls._info.lookup_with_tag(symbol_name) + except KeyError: + symbol = cls._create_lfric_symbol( + class_name, symbol_name, access=access) + return symbol + + @classmethod + def _add_to_symbol_table(cls, symbol): + ''' xxx ''' + #if isinstance(symbol.interface, symbols.ArgumentInterface): + # cls._info.append_argument(symbol, tag=symbol.name) + #else: + cls._info.add(symbol, tag=symbol.name) + if isinstance(symbol.datatype.precision, symbols.DataSymbol): + cls._add_precision_symbol(symbol.datatype.precision) + + @classmethod + def _append_to_arg_list(cls, symbol): + ''' xxx ''' + cls._info._argument_list.append(symbol) + # TODO this should really use SymbolTable.append_argument() but that + # *adds* the symbol to the table too. + # symbol_table.specify_argument_list([arg1]) + + @classmethod + def _field_name(cls, meta_arg): + '''Utility function providing the default field name from its meta_arg + metadata. + + :param meta_arg: metadata describing a field argument. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.FieldArgMetadata` + + :returns: the default name for this meta_arg field. + :rtype: str + + ''' + datatype = meta_arg.datatype[3:4] + meta_arg_index = cls._metadata.meta_args.index(meta_arg) + fs_name = cls._function_space_name(meta_arg.function_space) + return f"{datatype}field_{meta_arg_index+1}_{fs_name}" + + @classmethod + def _operator_name(cls, meta_arg): + '''Utility function providing the default field name from its meta_arg + metadata. + + :param meta_arg: metadata describing an lma operator argument. + :type meta_arg: + :py:class:`psyclone.domain.lfric.kernel.OperatorArgMetadata` + + :returns: the default name for this meta_arg field. + :rtype: str + + ''' + meta_arg_index = cls._metadata.meta_args.index(meta_arg) + return f"op_{meta_arg_index+1}" + + @classmethod + def _cma_operator_name(cls, meta_arg): + '''Utility function providing the default cma operator name from its + meta_arg metadata. + + :param meta_arg: metadata describing a cma operator argument. + :type meta_arg: :py:class:`psyclone.domain.lfric.kernel. + ColumnwiseOperatorArgMetadata` + + :returns: the default name for this meta_arg cma operator. + :rtype: str + + ''' + meta_arg_index = cls._metadata.meta_args.index(meta_arg) + return f"cma_op_{meta_arg_index+1}" + + @classmethod + def _function_space_name(cls, function_space): + '''Shortens the function space name if it is any_space_* or + any_discontinuous_space_*. + + :param str function_space: the function space name. + + :returns: a shortened function space name. + :rtype: str + + ''' + if "any_space_" in function_space: + return f"aspc{function_space[10:]}" + if "any_discontinuous_space_" in function_space: + return f"adspc{function_space[24:]}" + return function_space + + @classmethod + def _undf_name(cls, function_space): + ''' xxx ''' + function_space_name = cls._function_space_name(function_space) + return f"undf_{function_space_name}" + + @classmethod + def _ndf_name(cls, function_space): + ''' xxx ''' + function_space_name = cls._function_space_name(function_space) + return f"ndf_{function_space_name}" + + @classmethod + def _dofmap_name(cls, function_space): + ''' xxx ''' + function_space_name = cls._function_space_name(function_space) + return f"dofmap_{function_space_name}" + + @classmethod + def _fullmap_name(cls, function_space): + ''' xxx ''' + function_space_name = cls._function_space_name(function_space) + return f"full_map_{function_space_name}" diff --git a/src/psyclone/domain/lfric/kernel/lfric_kernel_metadata.py b/src/psyclone/domain/lfric/kernel/lfric_kernel_metadata.py index 153e19f003..948a8dfac3 100644 --- a/src/psyclone/domain/lfric/kernel/lfric_kernel_metadata.py +++ b/src/psyclone/domain/lfric/kernel/lfric_kernel_metadata.py @@ -305,6 +305,13 @@ def _validate_general_purpose_kernel(self): f"*cell_column' should only have meta_arg arguments " f"of type field, field vector, LMA operator or scalar" f", but found '{meta_arg.check_name}'")) + if (type(meta_arg) is ScalarArgMetadata and + meta_arg.access != 'gh_read'): + raise ParseError(self._validation_error_str( + f"Scalar arguments to general-purpose kernels with " + f"'operates_on == cell_column' must be read-only but " + f"found '{meta_arg.datatype}' scalar with " + f"'{meta_arg.access}' access")) # TODO issue #1953: constraints when operates_on == dofs # 1: They must have one and only one modified (i.e. written @@ -713,7 +720,7 @@ def create_from_psyir(symbol): def create_from_fparser2(fparser2_tree): '''Create an instance of this class from an fparser2 tree. - :param fparser2_tree: fparser2 tree containing the metadata \ + :param fparser2_tree: fparser2 tree containing the metadata for an LFRic Kernel. :type fparser2_tree: \ :py:class:`fparser.two.Fortran2003.Derived_Type_Ref` @@ -722,7 +729,7 @@ def create_from_fparser2(fparser2_tree): :rtype: :py:class:`psyclone.domain.lfric.kernel.psyir.\ LFRicKernelMetadata` - :raises ParseError: if one of the meta_args entries is an \ + :raises ParseError: if one of the meta_args entries is an unexpected type. :raises ParseError: if the metadata type does not extend kernel_type. @@ -779,6 +786,9 @@ def create_from_fparser2(fparser2_tree): kernel_metadata.procedure_name = \ LFRicKernelMetadata._get_procedure_name(fparser2_tree) + # Validate the metadata. + kernel_metadata.validate() + return kernel_metadata def lower_to_psyir(self): diff --git a/src/psyclone/domain/lfric/lfric_types.py b/src/psyclone/domain/lfric/lfric_types.py index 305bda150f..b94fe29594 100644 --- a/src/psyclone/domain/lfric/lfric_types.py +++ b/src/psyclone/domain/lfric/lfric_types.py @@ -457,6 +457,10 @@ class Array: ["fs_from", "fs_to"]), Array("DofMap", "LFRicIntegerScalarDataType", ["number of dofs"], ["fs"]), + Array("ColumnBandedDofMap", "LFRicIntegerScalarDataType", + ["number of dofs", "number of layers"], ["fs"]), + Array("CMADofMap", "LFRicIntegerScalarDataType", + ["number of rows or columns"], ["fs_from", "fs_to"]), Array("BasisFunctionQrXyoz", "LFRicRealScalarDataType", [LFRicTypes("LFRicDimension"), "number of dofs", "number of qr points in xy", diff --git a/src/psyclone/domain/lfric/metadata_to_arguments_rules.py b/src/psyclone/domain/lfric/metadata_to_arguments_rules.py index f363db908d..3ef11addba 100644 --- a/src/psyclone/domain/lfric/metadata_to_arguments_rules.py +++ b/src/psyclone/domain/lfric/metadata_to_arguments_rules.py @@ -84,12 +84,15 @@ def mapping(cls, metadata, info=None): to add to an existing object if required. :param metadata: the kernel metadata. - :type metadata: \ + :type metadata: py:class:`psyclone.domain.lfric.kernel.LFRicKernelMetadata` - :param info: optional object to initialise the _info \ + :param info: optional object to initialise the _info variable. Defaults to None. :type info: :py:class:`Object` + :returns: something + :rtype: + ''' cls._initialise(info) cls._metadata = metadata diff --git a/src/psyclone/gen_kernel_stub.py b/src/psyclone/gen_kernel_stub.py index 60c1eff5e8..1d08da73d7 100644 --- a/src/psyclone/gen_kernel_stub.py +++ b/src/psyclone/gen_kernel_stub.py @@ -41,11 +41,13 @@ call a stub) when presented with Kernel Metadata. ''' -from __future__ import print_function import os +import re import fparser -from psyclone.domain.lfric import LFRicKern, LFRicKernMetadata +from psyclone.domain.lfric import ( + LFRicKern, LFRicKernMetadata, FormalKernelArgsFromMetadata) +from psyclone.domain.lfric.kernel import LFRicKernelMetadata from psyclone.errors import GenerationError from psyclone.parse.utils import ParseError from psyclone.configuration import Config, LFRIC_API_NAMES @@ -60,9 +62,9 @@ def generate(filename, api=""): Kernel Metadata must be presented in the standard Kernel format. - :param str filename: the name of the file for which to create a \ + :param str filename: the name of the file for which to create a kernel stub for. - :param str api: the name of the API for which to create a kernel \ + :param str api: the name of the API for which to create a kernel stub. Must be one of the supported stub APIs. :returns: root of fparser1 parse tree for the stub routine. @@ -83,18 +85,31 @@ def generate(filename, api=""): if not os.path.isfile(filename): raise IOError(f"Kernel stub generator: File '{filename}' not found.") - # Drop cache - fparser.one.parsefortran.FortranParser.cache.clear() - fparser.logging.disable(fparser.logging.CRITICAL) + from psyclone.psyir.frontend import fortran + from psyclone.psyir import nodes + from psyclone.psyir.symbols import DataTypeSymbol + from psyclone.errors import InternalError + freader = fortran.FortranReader() try: - ast = fparser.api.parse(filename, ignore_comments=False) - - except (fparser.common.utils.AnalyzeError, AttributeError) as error: + kern_psyir = freader.psyir_from_file(filename) + except ValueError as err: raise ParseError(f"Kernel stub generator: Code appears to be invalid " - f"Fortran: {error}.") + f"Fortran: {err}.") from err - metadata = LFRicKernMetadata(ast) - kernel = LFRicKern() - kernel.load_meta(metadata) + table = kern_psyir.children[0].symbol_table + for sym in table.symbols: + if isinstance(sym, DataTypeSymbol) and not sym.is_import: + break + else: + raise InternalError("No DataTypeSymbol found.") - return kernel.gen_stub + metadata = LFRicKernelMetadata.create_from_psyir(sym) + new_table = FormalKernelArgsFromMetadata.mapping(metadata) + mod_name = re.sub(r"_type$", r"_mod", sym.name) + new_container = nodes.Container(mod_name) + # Add the metadata + new_container.symbol_table.add(sym) + kern_name = metadata.procedure_name + new_routine = nodes.Routine.create(kern_name, new_table, []) + new_container.addchild(new_routine) + return new_container diff --git a/src/psyclone/kernel_tools.py b/src/psyclone/kernel_tools.py index 544a9fc37f..3c297d0e18 100644 --- a/src/psyclone/kernel_tools.py +++ b/src/psyclone/kernel_tools.py @@ -161,20 +161,22 @@ def run(args): if args.gen == "alg": # Generate algorithm code. if api in LFRIC_API_NAMES: - alg_psyir = LFRicAlg().create_from_kernel("test_alg", - args.filename) - code = FortranWriter()(alg_psyir) + psyir = LFRicAlg().create_from_kernel("test_alg", + args.filename) else: print(f"Algorithm generation from kernel metadata is " f"not yet implemented for API '{api}'", file=sys.stderr) sys.exit(1) elif args.gen == "stub": # Generate kernel stub - code = gen_kernel_stub.generate(args.filename, api=api) + psyir = gen_kernel_stub.generate(args.filename, api=api) else: raise InternalError(f"Expected -gen option to be one of " f"{list(GEN_MODES.keys())} but got {args.gen}") + # Generate the output Fortran. + code = FortranWriter()(psyir) + except (IOError, ParseError, GenerationError, RuntimeError) as error: print("Error:", error, file=sys.stderr) sys.exit(1) diff --git a/src/psyclone/psyir/frontend/fortran.py b/src/psyclone/psyir/frontend/fortran.py index 481690b9dd..498e3f539c 100644 --- a/src/psyclone/psyir/frontend/fortran.py +++ b/src/psyclone/psyir/frontend/fortran.py @@ -44,7 +44,7 @@ from fparser.two import Fortran2003, pattern_tools from fparser.two.parser import ParserFactory from fparser.two.symbol_table import SYMBOL_TABLES -from fparser.two.utils import NoMatchError +from fparser.two.utils import NoMatchError, FortranSyntaxError from psyclone.configuration import Config from psyclone.psyir.frontend.fparser2 import Fparser2Reader from psyclone.psyir.nodes import Schedule, Assignment, Routine @@ -228,6 +228,8 @@ def psyir_from_file(self, file_path): :returns: PSyIR representing the provided Fortran file. :rtype: :py:class:`psyclone.psyir.nodes.Node` + :raises ValueError: if the file cannot be parsed. + ''' SYMBOL_TABLES.clear() @@ -242,7 +244,12 @@ def psyir_from_file(self, file_path): include_dirs=Config.get().include_paths, ignore_comments=self._ignore_comments) reader.set_format(FortranFormat(self._free_form, False)) - parse_tree = self._parser(reader) + try: + parse_tree = self._parser(reader) + except FortranSyntaxError as err: + raise ValueError( + f"File '{file_path}' could not be parsed, does it" + f" contain valid Fortran? Error was:\n{err}") from err _, filename = os.path.split(file_path) psyir = self._processor.generate_psyir(parse_tree, filename) diff --git a/src/psyclone/tests/domain/lfric/formal_kernel_args_from_metadata_test.py b/src/psyclone/tests/domain/lfric/formal_kernel_args_from_metadata_test.py new file mode 100644 index 0000000000..49d4a4c722 --- /dev/null +++ b/src/psyclone/tests/domain/lfric/formal_kernel_args_from_metadata_test.py @@ -0,0 +1,629 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2023-2024, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Author: R. W. Ford, STFC Daresbury Lab + +'''This module tests the FormalKernelArgsFromMetadata class.''' + +from collections import OrderedDict + +import pytest + +from psyclone.domain import lfric +from psyclone.errors import InternalError +from psyclone.psyir import symbols + + +def call_method(method_name, *args, metadata=None): + '''Utility function that initialises the FormalKernelArgsFromMetadata + class with optional metadata and a symbol table, then calls the + class method specified in argument 'method_name' with the + arguments specified in argument *args and returns the class. + + :param str method_name: the name of the method to test. + :param metadata: optional metadata required by some methods. + :type metadata: Optional[ + :py:class:`psyclone.domain.lfric.kernel.LFRicKernelMetadata`] + + :returns: a FormalKernelArgsFromMetadata class after the supplied + class method has been called. + :rtype: :py:class:`psyclone.domain.lfric.FormalKernelArgsFromMetadata` + + ''' + cls = lfric.FormalKernelArgsFromMetadata + cls._info = symbols.SymbolTable() + cls._metadata = metadata + getattr(cls, method_name)(*args) + return cls + + +def check_single_symbol( + method_name, datasymbol_name, symbol_name, *args, metadata=None, + check_unchanged=False): + '''Utility function that calls the method in argument 'method_name' + with the arguments stored in argument '*args' and checks that as a + result a symbol with name 'symbol_name' of type 'datasymbol_name' + is created. This function tests methods where a single symbol is + created. + + :param str method_name: the name of the method to test. + :param str datasymbol_name: the name of the expected symbol type + that is created. + :param str symbol_name: the expected name of the created symbol. + :param metadata: optional metadata required by some methods. + :type metadata: Optional[ + :py:class:`psyclone.domain.lfric.kernel.LFRicKernelMetadata`] + *** check unchanged *** + + :returns: xxx + + ''' + cls = call_method(method_name, *args, metadata=metadata) + lfric_class = lfric.LFRicTypes(datasymbol_name) + symbol = cls._info.lookup(symbol_name) + # pylint gets confused here + # pylint: disable=isinstance-second-argument-not-valid-type + assert isinstance(symbol, lfric_class) + # pylint: enable=isinstance-second-argument-not-valid-type + assert len(cls._info._argument_list) == 1 + assert cls._info._argument_list[0] is symbol + if check_unchanged: + # Check that the symbol remains unchanged if it has already been + # declared. + symbol_id = id(symbol) + # Reset the argument list as this method will have added it + # both to the symbol table and to the argument list whereas if + # it had already been declared by another method it will have + # only been added to the symbol table. + cls._info._argument_list = [] + getattr(cls, method_name)(*args) + symbol = cls._info.lookup_with_tag(symbol_name) + assert symbol_id == id(symbol) + assert len(cls._info._argument_list) == 1 + assert cls._info._argument_list[0] is symbol + + return cls + + +def check_symbols(cls, symbol_dict): + ''' xxx ''' + for symbol_name, lfric_class in symbol_dict.items(): + assert isinstance(cls._info.lookup(symbol_name), lfric_class) + + +def check_arg_symbols(cls, symbol_dict): + ''' xxx ''' + check_symbols(cls, symbol_dict) + assert len(cls._info._argument_list) == len(symbol_dict) + for idx, symbol_name in enumerate(symbol_dict.keys()): + assert cls._info._argument_list[idx].name == symbol_name + + +def check_common_cma_symbols(fs1, fs2): + ''' xxx ''' + operator_meta_arg = lfric.kernel.ColumnwiseOperatorArgMetadata( + "GH_REAL", "GH_WRITE", fs1, fs2) + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[operator_meta_arg]) + metadata.validate() + cls = call_method("_cma_operator", operator_meta_arg, metadata=metadata) + return cls + + +def test_fkafm_initialise(): + '''Test the _initialise() method.''' + cls = lfric.FormalKernelArgsFromMetadata + # No table supplied - one is created. + cls._initialise() + assert isinstance(cls._info, symbols.SymbolTable) + # Table supplied. + new_table = symbols.SymbolTable() + cls._initialise(new_table) + assert cls._info is new_table + # Wrong type of argument. + with pytest.raises(TypeError) as err: + cls._initialise("table") + assert ("Expecting the optional 'symbol_table' argument to be a " + "SymbolTable but found str" in str(err.value)) + + +def test_cell_position(): + ''' Test _cell_position method. ''' + check_single_symbol("_cell_position", "CellPositionDataSymbol", "cell") + + +def test_mesh_height(): + ''' Test _mesh_height method. ''' + check_single_symbol("_mesh_height", "MeshHeightDataSymbol", "nlayers") + + +def test_mesh_ncell2d_no_halos(): + ''' Test _mesh_ncell2d_no_halos method. ''' + check_single_symbol( + "_mesh_ncell2d_no_halos", "LFRicIntegerScalarDataSymbol", + "ncell_2d_no_halos") + + +def test_mesh_ncell2d(): + ''' Test _mesh_ncell2d method. ''' + symbol_name = "ncell_2d" + cls = check_single_symbol( + "_mesh_ncell2d", "LFRicIntegerScalarDataSymbol", symbol_name, + check_unchanged=True) + + +def test_cell_map(): + ''' Test _cell_map method. ''' + cls = call_method("_cell_map") + lfric_class = lfric.LFRicTypes("LFRicIntegerScalarDataSymbol") + # Symbols added to the symbol table and to the argument list. + check_arg_symbols(cls, OrderedDict( + [("cell_map", symbols.DataSymbol), ("ncell_f_per_c_x", lfric_class), + ("ncell_f_per_c_y", lfric_class), ("ncell_f", lfric_class)])) + + +def test_scalar(): + ''' Test _scalar method. ''' + # At least one field arg is required for the metadata to be valid + # even though we only want to test the scalar metadata. + field_meta_arg = lfric.kernel.FieldArgMetadata("GH_REAL", "GH_WRITE", "W3") + scalar_meta_arg = lfric.kernel.ScalarArgMetadata("GH_REAL", "GH_READ") + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[field_meta_arg, scalar_meta_arg]) + metadata.validate() + check_single_symbol( + "_scalar", "LFRicRealScalarDataSymbol", "rscalar_2", scalar_meta_arg, + metadata=metadata) + + +def test_field(): + ''' Test _field method. ''' + field_meta_arg = lfric.kernel.FieldArgMetadata("GH_REAL", "GH_WRITE", "W3") + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[field_meta_arg]) + metadata.validate() + cls = check_single_symbol( + "_field", "RealFieldDataSymbol", "rfield_1_w3", field_meta_arg, + metadata=metadata) + lfric_class = lfric.LFRicTypes("NumberOfUniqueDofsDataSymbol") + # Symbols added to the symbol table but not to the argument list. + check_symbols(cls, {"undf_w3": lfric_class}) + + +def test_field_vector(): + ''' Test _field_vector method. ''' + field_meta_arg = lfric.kernel.FieldVectorArgMetadata( + "GH_REAL", "GH_WRITE", "W3", "3") + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[field_meta_arg]) + metadata.validate() + cls = call_method("_field_vector", field_meta_arg, metadata=metadata) + lfric_class = lfric.LFRicTypes("RealFieldDataSymbol") + # Symbols added to the symbol table and to the argument list. + check_arg_symbols(cls, OrderedDict( + [("rfield_1_w3_v1", lfric_class), ("rfield_1_w3_v2", lfric_class), + ("rfield_1_w3_v3", lfric_class)])) + lfric_class = lfric.LFRicTypes("NumberOfUniqueDofsDataSymbol") + # Symbols added to the symbol table but not to the argument list. + check_symbols(cls, {"undf_w3": lfric_class}) + + +def test_operator(): + ''' Test _operator method. ''' + operator_meta_arg = lfric.kernel.OperatorArgMetadata( + "GH_REAL", "GH_WRITE", "W3", "W2") + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[operator_meta_arg]) + metadata.validate() + cls = call_method("_operator", operator_meta_arg, metadata=metadata) + lfric_int_class = lfric.LFRicTypes("LFRicIntegerScalarDataSymbol") + lfric_op_class = lfric.LFRicTypes("OperatorDataSymbol") + # Symbols added to the symbol table and to the argument list. + check_arg_symbols(cls, OrderedDict( + [("op_1_ncell_3d", lfric_int_class), ("op_1", lfric_op_class)])) + lfric_dofs_class = lfric.LFRicTypes("NumberOfDofsDataSymbol") + # Symbols added to the symbol table but not to the argument list. + check_symbols( + cls, {"ndf_w3": lfric_dofs_class, "ndf_w2": lfric_dofs_class}) + + +def test_cma_operator(): + ''' Test _cma_operator method. ''' + # to/from function spaces differ so there is an additional + # argument. + cls = check_common_cma_symbols("W3", "W2") + lfric_op_class = lfric.LFRicTypes("OperatorDataSymbol") + lfric_int_class = lfric.LFRicTypes("LFRicIntegerScalarDataSymbol") + check_arg_symbols(cls, OrderedDict( + [("cma_op_1", lfric_op_class), + ("nrow_cma_op_1", lfric_int_class), + ("ncol_cma_op_1", lfric_int_class), + ("bandwidth_cma_op_1", lfric_int_class), + ("alpha_cma_op_1", lfric_int_class), + ("beta_cma_op_1", lfric_int_class), + ("gamma_m_cma_op_1", lfric_int_class), + ("gamma_p_cma_op_1", lfric_int_class)])) + check_symbols(cls, {"ncell_2d": lfric_int_class}) + + # to/from function spaces are the same so there is no additional + # argument. + cls = check_common_cma_symbols("W3", "W3") + with pytest.raises(KeyError): + cls._info.lookup("ncol_cma_op_1") + check_arg_symbols(cls, OrderedDict( + [("cma_op_1", lfric_op_class), + ("nrow_cma_op_1", lfric_int_class), + ("bandwidth_cma_op_1", lfric_int_class), + ("alpha_cma_op_1", lfric_int_class), + ("beta_cma_op_1", lfric_int_class), + ("gamma_m_cma_op_1", lfric_int_class), + ("gamma_p_cma_op_1", lfric_int_class)])) + check_symbols(cls, {"ncell_2d": lfric_int_class}) + + +# pylint: disable=too-many-statements +def test_ref_element_properties(monkeypatch): + ''' Test _ref_element_properties method. ''' + lfric_qr_xy_class = lfric.LFRicTypes("NumberOfQrPointsInXyDataSymbol") + lfric_qr_z_class = lfric.LFRicTypes("NumberOfQrPointsInZDataSymbol") + lfric_qr_faces_class = lfric.LFRicTypes( + "NumberOfQrPointsInFacesDataSymbol") + + # Horizontal + meta_ref_element = [ + lfric.kernel.MetaRefElementArgMetadata("normals_to_horizontal_faces")] + cls = call_method("_ref_element_properties", meta_ref_element) + check_arg_symbols(cls, OrderedDict([ + ("nfaces_re_h", lfric_qr_xy_class), + ("normals_to_horizontal_faces", symbols.DataSymbol)])) + symbol = cls._info.lookup("normals_to_horizontal_faces") + assert symbol.is_array + assert len(symbol.datatype.shape) == 2 + assert symbol.datatype.shape[0].upper.value == "3" + assert symbol.datatype.shape[1].upper.symbol.name == "nfaces_re_h" + + # Vertical + meta_ref_element = [ + lfric.kernel.MetaRefElementArgMetadata("normals_to_vertical_faces")] + cls = call_method("_ref_element_properties", meta_ref_element) + check_arg_symbols(cls, OrderedDict([ + ("nfaces_re_v", lfric_qr_z_class), + ("normals_to_vertical_faces", symbols.DataSymbol)])) + symbol = cls._info.lookup("normals_to_vertical_faces") + assert symbol.is_array + assert len(symbol.datatype.shape) == 2 + assert symbol.datatype.shape[0].upper.value == "3" + assert symbol.datatype.shape[1].upper.symbol.name == "nfaces_re_v" + + # General + meta_ref_element = [ + lfric.kernel.MetaRefElementArgMetadata("normals_to_faces")] + cls = call_method("_ref_element_properties", meta_ref_element) + check_arg_symbols(cls, OrderedDict([ + ("nfaces_re", lfric_qr_faces_class), + ("normals_to_faces", symbols.DataSymbol)])) + symbol = cls._info.lookup("normals_to_faces") + assert symbol.is_array + assert len(symbol.datatype.shape) == 2 + assert symbol.datatype.shape[0].upper.value == "3" + assert symbol.datatype.shape[1].upper.symbol.name == "nfaces_re" + + # All + meta_ref_element = [ + lfric.kernel.MetaRefElementArgMetadata("normals_to_horizontal_faces"), + lfric.kernel.MetaRefElementArgMetadata("normals_to_vertical_faces"), + lfric.kernel.MetaRefElementArgMetadata("normals_to_faces")] + cls = call_method("_ref_element_properties", meta_ref_element) + check_arg_symbols(cls, OrderedDict([ + ("nfaces_re_h", lfric_qr_xy_class), + ("nfaces_re_v", lfric_qr_z_class), + ("nfaces_re", lfric_qr_faces_class), + ("normals_to_horizontal_faces", symbols.DataSymbol), + ("normals_to_vertical_faces", symbols.DataSymbol), + ("normals_to_faces", symbols.DataSymbol)])) + symbol = cls._info.lookup("normals_to_horizontal_faces") + assert symbol.is_array + assert len(symbol.datatype.shape) == 2 + assert symbol.datatype.shape[0].upper.value == "3" + assert symbol.datatype.shape[1].upper.symbol.name == "nfaces_re_h" + symbol = cls._info.lookup("normals_to_vertical_faces") + assert symbol.is_array + assert len(symbol.datatype.shape) == 2 + assert symbol.datatype.shape[0].upper.value == "3" + assert symbol.datatype.shape[1].upper.symbol.name == "nfaces_re_v" + symbol = cls._info.lookup("normals_to_faces") + assert symbol.is_array + assert len(symbol.datatype.shape) == 2 + assert symbol.datatype.shape[0].upper.value == "3" + assert symbol.datatype.shape[1].upper.symbol.name == "nfaces_re" + + # Exception + meta_ref_element = [ + lfric.kernel.MetaRefElementArgMetadata("normals_to_faces")] + monkeypatch.setattr(meta_ref_element[0], "_reference_element", "invalid") + with pytest.raises(InternalError) as info: + _ = call_method("_ref_element_properties", meta_ref_element) + assert ("Unsupported reference element property 'invalid' found." + in str(info.value)) + + +def test_mesh_properties(monkeypatch): + ''' Test _mesh_properties method. ''' + # nfaces_re_h already passed + field_meta_arg = lfric.kernel.FieldArgMetadata("GH_REAL", "GH_WRITE", "W3") + meta_ref_element = lfric.kernel.MetaRefElementArgMetadata( + "normals_to_horizontal_faces") + meta_mesh_arg = lfric.kernel.MetaMeshArgMetadata("adjacent_face") + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[field_meta_arg], + meta_ref_element=[meta_ref_element], meta_mesh=[meta_mesh_arg]) + metadata.validate() + cls = call_method("_mesh_properties", [meta_mesh_arg], metadata=metadata) + check_arg_symbols(cls, OrderedDict([ + ("adjacent_face", symbols.DataSymbol)])) + symbol = cls._info.lookup("adjacent_face") + assert symbol.is_array + assert len(symbol.datatype.shape) == 1 + assert symbol.datatype.shape[0].upper.symbol.name == "nfaces_re_h" + + # nfaces_re_h not yet passed + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[field_meta_arg], + meta_mesh=[meta_mesh_arg]) + metadata.validate() + cls = call_method("_mesh_properties", [meta_mesh_arg], metadata=metadata) + lfric_qr_xy_class = lfric.LFRicTypes("NumberOfQrPointsInXyDataSymbol") + check_arg_symbols(cls, OrderedDict([ + ("nfaces_re_h", lfric_qr_xy_class), + ("adjacent_face", symbols.DataSymbol)])) + symbol = cls._info.lookup("adjacent_face") + assert symbol.is_array + assert len(symbol.datatype.shape) == 1 + assert symbol.datatype.shape[0].upper.symbol.name == "nfaces_re_h" + + # Exception + monkeypatch.setattr(meta_mesh_arg, "_mesh", "invalid") + with pytest.raises(InternalError) as info: + _ = call_method("_mesh_properties", [meta_mesh_arg], metadata=metadata) + assert ("Unexpected mesh property 'invalid' found. Expected " + "'adjacent_face'." in str(info.value)) + + +def test_fs_common(): + ''' Test _fs_common method. ''' + function_space = "w3" + symbol_name = lfric.FormalKernelArgsFromMetadata._ndf_name(function_space) + cls = check_single_symbol( + "_fs_common", "NumberOfDofsDataSymbol", symbol_name, function_space, + check_unchanged=True) + + +def test_fs_compulsory_field(): + ''' Test _fs_compulsory_field method. ''' + function_space = "w3" + cls = call_method("_fs_compulsory_field", function_space) + check_fs_compulsory_field(cls, function_space) + + +def check_fs_compulsory_field(cls, function_space): + ''' xxx ''' + undf_name = lfric.FormalKernelArgsFromMetadata._undf_name(function_space) + dofmap_name = lfric.FormalKernelArgsFromMetadata._dofmap_name( + function_space) + # Check that undf and dofmap symbols are added to the symbol table + # and to the argument list. + undf_class = lfric.LFRicTypes("NumberOfUniqueDofsDataSymbol") + check_arg_symbols(cls, OrderedDict( + [(undf_name, undf_class), (dofmap_name, symbols.DataSymbol)])) + # Check that dofmap is an array with the expected extent. + dofmap_symbol = cls._info.lookup(dofmap_name) + assert dofmap_symbol.is_array + assert len(dofmap_symbol.datatype.shape) == 1 + ndf_name = lfric.FormalKernelArgsFromMetadata._ndf_name(function_space) + assert dofmap_symbol.datatype.shape[0].upper.symbol.name == ndf_name + # Check that the method works if undf has already been added to + # the symbol table. + # Remove dofmap from the symbol table and remove both symbols from + # the argument list. This should leave the symbol table containing + # just undf. + cls._info._argument_list = [] + norm_name = cls._info._normalize(dofmap_name) + cls._info._symbols.pop(norm_name) + # remove dofmap tag if there is one + for tag, tagged_symbol in list(cls._info._tags.items()): + if dofmap_symbol is tagged_symbol: + del cls._info._tags[tag] + + # Call the method again and check that the undf symbol does not + # change and that a dofmap symbol is added to the symbol table. + undf_symbol = cls._info.lookup(undf_name) + cls._fs_compulsory_field(function_space) + check_arg_symbols(cls, OrderedDict( + [(undf_name, undf_class), (dofmap_name, symbols.DataSymbol)])) + assert cls._info.lookup(undf_name) is undf_symbol + + +def test_fs_intergrid(): + ''' Test _fs_intergrid method. ''' + # gh_fine + function_space = "w3" + intergrid_meta_arg = lfric.kernel.InterGridArgMetadata( + "GH_REAL", "GH_WRITE", function_space, "GH_FINE") + cls = call_method("_fs_intergrid", intergrid_meta_arg) + ndf_name = lfric.FormalKernelArgsFromMetadata._ndf_name(function_space) + undf_name = lfric.FormalKernelArgsFromMetadata._undf_name(function_space) + fullmap_name = lfric.FormalKernelArgsFromMetadata._fullmap_name( + function_space) + ndf_class = lfric.LFRicTypes("NumberOfDofsDataSymbol") + undf_class = lfric.LFRicTypes("NumberOfUniqueDofsDataSymbol") + check_arg_symbols(cls, OrderedDict( + [(ndf_name, ndf_class), (undf_name, undf_class), + (fullmap_name, symbols.DataSymbol)])) + # Check that fullmap is an array with the expected extent. + fullmap_symbol = cls._info.lookup(fullmap_name) + assert fullmap_symbol.is_array + assert len(fullmap_symbol.datatype.shape) == 2 + ndf_name = lfric.FormalKernelArgsFromMetadata._ndf_name(function_space) + assert fullmap_symbol.datatype.shape[0].upper.symbol.name == ndf_name + assert fullmap_symbol.datatype.shape[1].upper.symbol.name == "ncell_f" + + # if undf and ndf are already declared + undf_symbol = cls._info.lookup(undf_name) + ndf_symbol = cls._info.lookup(ndf_name) + from psyclone.psyir.symbols import SymbolTable + symbol_table = SymbolTable() + symbol_table.add(undf_symbol, tag=undf_name) + symbol_table.add(ndf_symbol, tag=ndf_name) + cls._info = symbol_table + cls._fs_intergrid(intergrid_meta_arg) + check_arg_symbols(cls, OrderedDict( + [(ndf_name, ndf_class), (undf_name, undf_class), + (fullmap_name, symbols.DataSymbol)])) + assert cls._info.lookup(undf_name) is undf_symbol + assert cls._info.lookup(ndf_name) is ndf_symbol + + # gh_coarse + intergrid_meta_arg = lfric.kernel.InterGridArgMetadata( + "GH_REAL", "GH_WRITE", function_space, "GH_COARSE") + cls = call_method("_fs_intergrid", intergrid_meta_arg) + check_fs_compulsory_field(cls, function_space) + + +def test_basis_or_diff_basis_dimension(): + ''' TODO ''' + pass + + +def test_basis_dimension(): + '''Test the _basis_dimension utility method. Test with one example of + each option (returning 1, returning 3, or raising an + exception. Also test that function_space values can be lower or upper + case. + + ''' + cls = lfric.FormalKernelArgsFromMetadata + assert cls._basis_dimension("W0") == 1 + assert cls._basis_dimension("w1") == 3 + with pytest.raises(ValueError) as info: + cls._basis_dimension("invalid") + assert ("Unexpected function space value 'invalid' found in " + "basis_dimension. Expected one of ['w0', 'w2trace', 'w2htrace', " + "'w2vtrace', 'w3', 'wtheta', 'wchi', 'w1', 'w2', 'w2h', 'w2v', " + "'w2broken', 'any_w2']." in str(info.value)) + + +def test_diff_basis_dimension(): + '''Test the _diff_basis_dimension utility method. Test with one + example of each case (returning 1, returning 3, or raising an + exception. Also test that function_space values can be upper or + lower case. + + ''' + cls = lfric.FormalKernelArgsFromMetadata + assert cls._diff_basis_dimension("W2") == 1 + assert cls._diff_basis_dimension("w0") == 3 + with pytest.raises(ValueError) as info: + cls._diff_basis_dimension("invalid") + assert ("Unexpected function space value 'invalid' found in " + "diff_basis_dimension. Expected one of ['w2', 'w2h', 'w2v', " + "'w2broken', 'any_w2', 'w0', 'w1', 'w2trace', 'w2htrace', " + "'w2vtrace', 'w3', 'wtheta', 'wchi']." in str(info.value)) + + +def test_basis_or_diff_basis(): + ''' TODO ''' + # gh_quadrature_* + # name = "_"_. + # gh_quadrature_xyoz + # (dimension, number_of_dofs, np_xy, np_z). + # gh_quadrature_face + # (dimension, number_of_dofs, np_xyz, nfaces) + # gh_quadrature_edge + # (dimension, number_of_dofs, np_xyz, nedges) + # gh_evaluator + # name = "_""_on_". + # (dimension, number_of_dofs, ndf_) + + +def test_basis(): + ''' Test _basis method. ''' + function_space = "w3" + field_meta_arg = lfric.kernel.FieldArgMetadata("GH_REAL", "GH_WRITE", "W3") + meta_funcs_arg = lfric.kernel.MetaFuncsArgMetadata(function_space, + basis_function=True) + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[field_meta_arg], + meta_funcs=[meta_funcs_arg], shapes=["gh_quadrature_xyoz"]) + metadata.validate() + cls = call_method("_basis", function_space, metadata=metadata) + check_arg_symbols(cls, OrderedDict( + [("np_xy_qr_xyoz", symbols.DataSymbol), + ("np_z_qr_xyoz", symbols.DataSymbol), + ("basis_w3_qr_xyoz", symbols.DataSymbol)])) + # Check that basis is an array with the expected extent. + basis_symbol = cls._info.lookup("basis_w3_qr_xyoz") + assert basis_symbol.is_array + assert len(basis_symbol.datatype.shape) == 4 + ndf_name = lfric.FormalKernelArgsFromMetadata._ndf_name(function_space) + assert basis_symbol.datatype.shape[0].upper.value == "1" + assert basis_symbol.datatype.shape[1].upper.symbol.name == ndf_name + assert basis_symbol.datatype.shape[2].upper.symbol.name == "np_xy_qr_xyoz" + assert basis_symbol.datatype.shape[3].upper.symbol.name == "np_z_qr_xyoz" + + +def test_diff_basis(): + ''' Test _diff_basis method. ''' + function_space = "w3" + field_meta_arg = lfric.kernel.FieldArgMetadata("GH_REAL", "GH_WRITE", "W3") + meta_funcs_arg = lfric.kernel.MetaFuncsArgMetadata( + function_space, diff_basis_function=True) + metadata = lfric.kernel.LFRicKernelMetadata( + operates_on="cell_column", meta_args=[field_meta_arg], + meta_funcs=[meta_funcs_arg], shapes=["gh_quadrature_xyoz"]) + metadata.validate() + cls = call_method("_diff_basis", function_space, metadata=metadata) + check_arg_symbols(cls, OrderedDict( + [("np_xy_qr_xyoz", symbols.DataSymbol), + ("np_z_qr_xyoz", symbols.DataSymbol), + ("diff_basis_w3_qr_xyoz", symbols.DataSymbol)])) + # Check that diff basis is an array with the expected extent. + diff_basis_symbol = cls._info.lookup("diff_basis_w3_qr_xyoz") + assert diff_basis_symbol.is_array + assert len(diff_basis_symbol.datatype.shape) == 4 + ndf_name = lfric.FormalKernelArgsFromMetadata._ndf_name(function_space) + assert diff_basis_symbol.datatype.shape[0].upper.value == "3" + assert diff_basis_symbol.datatype.shape[1].upper.symbol.name == ndf_name + assert (diff_basis_symbol.datatype.shape[2].upper.symbol.name == + "np_xy_qr_xyoz") + assert (diff_basis_symbol.datatype.shape[3].upper.symbol.name == + "np_z_qr_xyoz") diff --git a/src/psyclone/tests/domain/lfric/kernel/lfric_kernel_metadata_test.py b/src/psyclone/tests/domain/lfric/kernel/lfric_kernel_metadata_test.py index cfa2493f96..a41421894d 100644 --- a/src/psyclone/tests/domain/lfric/kernel/lfric_kernel_metadata_test.py +++ b/src/psyclone/tests/domain/lfric/kernel/lfric_kernel_metadata_test.py @@ -915,10 +915,36 @@ def test_validate_intergrid_kernel(): # or gh_evaluator_targets should exist, but this is not yet checked. METADATA = ( "type, extends(kernel_type) :: testkern_type\n" - " type(arg_type), dimension(7) :: meta_args = &\n" + " type(arg_type), dimension(4) :: meta_args = &\n" " (/ arg_type(gh_scalar, gh_real, gh_read), &\n" " arg_type(gh_field, gh_real, gh_inc, w1), &\n" " arg_type(gh_field*3, gh_real, gh_read, w2), &\n" + " arg_type(gh_operator, gh_real, gh_read, w2, w3) &\n" + " /)\n" + " type(func_type), dimension(2) :: meta_funcs = &\n" + " (/ func_type(w1, gh_basis), &\n" + " func_type(w2, gh_basis, gh_diff_basis) &\n" + " /)\n" + " type(reference_element_data_type), dimension(2) :: &\n" + " meta_reference_element = &\n" + " (/ reference_element_data_type(normals_to_horizontal_faces), &\n" + " reference_element_data_type(normals_to_vertical_faces) &\n" + " /)\n" + " type(mesh_data_type) :: meta_mesh(1) = &\n" + " (/ mesh_data_type(adjacent_face) /)\n" + " integer :: gh_shape = gh_quadrature_XYoZ\n" + " integer :: gh_evaluator_targets(2) = (/ w0, w3 /)\n" + " integer :: operates_on = cell_column\n" + " contains\n" + " procedure, nopass :: code => testkern_code\n" + "end type testkern_type\n") + +INTERGRID_METADATA = ( + "type, extends(kernel_type) :: testkern_type\n" + " type(arg_type), dimension(6) :: meta_args = &\n" + " (/ &\n" # arg_type(gh_scalar, gh_real, gh_read), &\n" + " arg_type(gh_field, gh_real, gh_inc, w1), &\n" + " arg_type(gh_field*3, gh_real, gh_read, w2), &\n" " arg_type(gh_field, gh_real, gh_read, w2, " "mesh_arg=gh_coarse), &\n" " arg_type(gh_field*3, gh_real, gh_read, w2, " @@ -968,14 +994,11 @@ def test_create_from_psyir(fortran_reader): assert metadata.evaluator_targets == ["w0", "w3"] assert isinstance(metadata.meta_args, list) - assert len(metadata.meta_args) == 7 + assert len(metadata.meta_args) == 4 assert isinstance(metadata.meta_args[0], ScalarArgMetadata) assert isinstance(metadata.meta_args[1], FieldArgMetadata) assert isinstance(metadata.meta_args[2], FieldVectorArgMetadata) - assert isinstance(metadata.meta_args[3], InterGridArgMetadata) - assert isinstance(metadata.meta_args[4], InterGridVectorArgMetadata) - assert isinstance(metadata.meta_args[5], OperatorArgMetadata) - assert isinstance(metadata.meta_args[6], ColumnwiseOperatorArgMetadata) + assert isinstance(metadata.meta_args[3], OperatorArgMetadata) assert isinstance(metadata.meta_funcs, list) assert isinstance(metadata.meta_funcs[0], MetaFuncsArgMetadata) @@ -1027,14 +1050,14 @@ def test_create_from_fparser2(procedure_format): assert metadata.evaluator_targets == ["w0", "w3"] assert isinstance(metadata.meta_args, list) - assert len(metadata.meta_args) == 7 + assert len(metadata.meta_args) == 4 assert isinstance(metadata.meta_args[0], ScalarArgMetadata) assert isinstance(metadata.meta_args[1], FieldArgMetadata) assert isinstance(metadata.meta_args[2], FieldVectorArgMetadata) - assert isinstance(metadata.meta_args[3], InterGridArgMetadata) - assert isinstance(metadata.meta_args[4], InterGridVectorArgMetadata) - assert isinstance(metadata.meta_args[5], OperatorArgMetadata) - assert isinstance(metadata.meta_args[6], ColumnwiseOperatorArgMetadata) + #assert isinstance(metadata.meta_args[3], InterGridArgMetadata) + #assert isinstance(metadata.meta_args[4], InterGridVectorArgMetadata) + assert isinstance(metadata.meta_args[3], OperatorArgMetadata) + #assert isinstance(metadata.meta_args[6], ColumnwiseOperatorArgMetadata) assert isinstance(metadata.meta_funcs, list) assert isinstance(metadata.meta_funcs[0], MetaFuncsArgMetadata) assert isinstance(metadata.meta_funcs[1], MetaFuncsArgMetadata) @@ -1060,8 +1083,9 @@ def test_create_from_fparser2_no_optional(): ''' metadata = ( "type, extends(kernel_type) :: testkern_type\n" - " type(arg_type), dimension(1) :: meta_args = &\n" - " (/ arg_type(gh_scalar, gh_real, gh_read) /)\n" + " type(arg_type), dimension(2) :: meta_args = &\n" + " (/ arg_type(gh_scalar, gh_real, gh_read), &\n" + " arg_type(gh_field, gh_real, gh_write, w3) /)\n" " integer :: operates_on = cell_column\n" " contains\n" " procedure, nopass :: code => testkern_code\n" @@ -1108,7 +1132,7 @@ def test_create_from_fparser2_error(): _ = LFRicKernelMetadata.create_from_fparser2(fparser2_tree) assert ("The metadata type declaration should extend kernel_type, but " "found 'TYPE :: testkern_type' in TYPE :: " - "testkern_type\n TYPE(arg_type), DIMENSION(7)" in str(info.value)) + "testkern_type\n TYPE(arg_type), DIMENSION(4)" in str(info.value)) # metadata type extends incorrect type fparser2_tree = LFRicKernelMetadata.create_fparser2(METADATA.replace( @@ -1118,7 +1142,7 @@ def test_create_from_fparser2_error(): assert ("The metadata type declaration should extend kernel_type, but " "found 'TYPE, EXTENDS(invalid_type) :: testkern_type' in TYPE, " "EXTENDS(invalid_type) :: testkern_type\n TYPE(arg_type), " - "DIMENSION(7)" in str(info.value)) + "DIMENSION(4)" in str(info.value)) def test_lower_to_psyir(): @@ -1215,14 +1239,11 @@ def test_fortran_string(): result = metadata.fortran_string() expected = ( "TYPE, PUBLIC, EXTENDS(kernel_type) :: testkern_type\n" - " type(ARG_TYPE) :: META_ARGS(7) = (/ &\n" + " type(ARG_TYPE) :: META_ARGS(4) = (/ &\n" " arg_type(gh_scalar, gh_real, gh_read), &\n" " arg_type(gh_field, gh_real, gh_inc, w1), &\n" " arg_type(gh_field*3, gh_real, gh_read, w2), &\n" - " arg_type(gh_field, gh_real, gh_read, w2, mesh_arg=gh_coarse), &\n" - " arg_type(gh_field*3, gh_real, gh_read, w2, mesh_arg=gh_fine), &\n" - " arg_type(gh_operator, gh_real, gh_read, w2, w3), &\n" - " arg_type(gh_columnwise_operator, gh_real, gh_read, w3, w0)/)\n" + " arg_type(gh_operator, gh_real, gh_read, w2, w3)/)\n" " type(FUNC_TYPE) :: META_FUNCS(2) = (/ &\n" " func_type(w1, gh_basis), &\n" " func_type(w2, gh_basis, gh_diff_basis)/)\n" @@ -1252,14 +1273,11 @@ def test_fortran_string_no_procedure(): result = metadata.fortran_string() expected = ( "TYPE, PUBLIC, EXTENDS(kernel_type) :: testkern_type\n" - " type(ARG_TYPE) :: META_ARGS(7) = (/ &\n" + " type(ARG_TYPE) :: META_ARGS(4) = (/ &\n" " arg_type(gh_scalar, gh_real, gh_read), &\n" " arg_type(gh_field, gh_real, gh_inc, w1), &\n" " arg_type(gh_field*3, gh_real, gh_read, w2), &\n" - " arg_type(gh_field, gh_real, gh_read, w2, mesh_arg=gh_coarse), &\n" - " arg_type(gh_field*3, gh_real, gh_read, w2, mesh_arg=gh_fine), &\n" - " arg_type(gh_operator, gh_real, gh_read, w2, w3), &\n" - " arg_type(gh_columnwise_operator, gh_real, gh_read, w3, w0)/)\n" + " arg_type(gh_operator, gh_real, gh_read, w2, w3)/)\n" " type(FUNC_TYPE) :: META_FUNCS(2) = (/ &\n" " func_type(w1, gh_basis), &\n" " func_type(w2, gh_basis, gh_diff_basis)/)\n" diff --git a/src/psyclone/tests/domain/lfric/lfric_dofmaps_test.py b/src/psyclone/tests/domain/lfric/lfric_dofmaps_test.py index 3c20cc5eb5..56e1814a7c 100644 --- a/src/psyclone/tests/domain/lfric/lfric_dofmaps_test.py +++ b/src/psyclone/tests/domain/lfric/lfric_dofmaps_test.py @@ -180,38 +180,36 @@ def test_unique_fs_comments(): assert output in code -def test_stub_decl_dofmaps(): +def test_stub_decl_dofmaps(fortran_writer): ''' Check that LFRicDofmaps generates the expected declarations in the stub. ''' - result = generate(os.path.join(BASE_PATH, - "columnwise_op_asm_kernel_mod.F90"), - api=TEST_API) + psyir = generate(os.path.join(BASE_PATH, + "columnwise_op_asm_kernel_mod.F90"), + api=TEST_API) + result = fortran_writer(psyir) + assert "integer(kind=i_def), intent(in) :: nrow_cma_op_2" in result + assert "integer(kind=i_def), intent(in) :: ncol_cma_op_2" in result - assert ("INTEGER(KIND=i_def), intent(in) :: cma_op_2_nrow, cma_op_2_ncol" - in str(result)) - -def test_lfricdofmaps_stub_gen(): +def test_lfricdofmaps_stub_gen(fortran_writer): ''' Test the kernel-stub generator for a CMA apply kernel. This has two fields and one CMA operator as arguments. ''' - result = generate(os.path.join(BASE_PATH, - "columnwise_op_app_kernel_mod.F90"), - api=TEST_API) - + psyir = generate(os.path.join(BASE_PATH, + "columnwise_op_app_kernel_mod.F90"), + api=TEST_API) + output = fortran_writer(psyir) expected = ( - " SUBROUTINE columnwise_op_app_kernel_code(cell, ncell_2d, " - "field_1_aspc1_field_1, field_2_aspc2_field_2, cma_op_3, " - "cma_op_3_nrow, cma_op_3_ncol, cma_op_3_bandwidth, cma_op_3_alpha, " - "cma_op_3_beta, cma_op_3_gamma_m, cma_op_3_gamma_p, " - "ndf_aspc1_field_1, undf_aspc1_field_1, map_aspc1_field_1, " - "cma_indirection_map_aspc1_field_1, ndf_aspc2_field_2, " - "undf_aspc2_field_2, map_aspc2_field_2, " - "cma_indirection_map_aspc2_field_2)\n" + "subroutine columnwise_op_app_kernel_code(cell, ncell_2d, " + "rfield_1_aspc1, rfield_2_aspc2, cma_op_3, " + "nrow_cma_op_3, ncol_cma_op_3, bandwidth_cma_op_3, alpha_cma_op_3, " + "beta_cma_op_3, gamma_m_cma_op_3, gamma_p_cma_op_3, " + "ndf_aspc1, undf_aspc1, dofmap_aspc1, cma_dofmap_aspc1, " + "ndf_aspc2, undf_aspc2, dofmap_aspc2, cma_dofmap_aspc2)\n" ) - assert expected in str(result) + assert expected in output diff --git a/src/psyclone/tests/domain/lfric/lfric_scalar_stubgen_test.py b/src/psyclone/tests/domain/lfric/lfric_scalar_stubgen_test.py index 8bf714ce90..597b7f11eb 100644 --- a/src/psyclone/tests/domain/lfric/lfric_scalar_stubgen_test.py +++ b/src/psyclone/tests/domain/lfric/lfric_scalar_stubgen_test.py @@ -41,7 +41,6 @@ LFRic scalar arguments. ''' -from __future__ import absolute_import, print_function import os import pytest @@ -84,46 +83,42 @@ def test_lfricscalars_stub_err(): f"{const.VALID_SCALAR_DATA_TYPES}." in str(err.value)) -def test_stub_generate_with_scalars(): +def test_stub_generate_with_scalars(fortran_writer): ''' Check that the stub generate produces the expected output when the kernel has scalar arguments. ''' - result = generate( + psyir = generate( os.path.join(BASE_PATH, "testkern_three_scalars_mod.f90"), api=TEST_API) - + result = fortran_writer(psyir) expected = ( - " MODULE testkern_three_scalars_mod\n" - " IMPLICIT NONE\n" - " CONTAINS\n" - " SUBROUTINE testkern_three_scalars_code(nlayers, rscalar_1, " - "field_2_w1, field_3_w2, field_4_w2, field_5_w3, lscalar_6, " - "iscalar_7, ndf_w1, undf_w1, map_w1, ndf_w2, undf_w2, map_w2, " - "ndf_w3, undf_w3, map_w3)\n" - " USE constants_mod\n" - " IMPLICIT NONE\n" - " INTEGER(KIND=i_def), intent(in) :: nlayers\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_w1\n" - " INTEGER(KIND=i_def), intent(in), dimension(ndf_w1) :: map_w1\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_w2\n" - " INTEGER(KIND=i_def), intent(in), dimension(ndf_w2) :: map_w2\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_w3\n" - " INTEGER(KIND=i_def), intent(in), dimension(ndf_w3) :: map_w3\n" - " INTEGER(KIND=i_def), intent(in) :: undf_w1, undf_w2, undf_w3\n" - " REAL(KIND=r_def), intent(in) :: rscalar_1\n" - " INTEGER(KIND=i_def), intent(in) :: iscalar_7\n" - " LOGICAL(KIND=l_def), intent(in) :: lscalar_6\n" - " REAL(KIND=r_def), intent(inout), dimension(undf_w1) :: " - "field_2_w1\n" - " REAL(KIND=r_def), intent(in), dimension(undf_w2) :: " - "field_3_w2\n" - " REAL(KIND=r_def), intent(in), dimension(undf_w2) :: " - "field_4_w2\n" - " REAL(KIND=r_def), intent(in), dimension(undf_w3) :: " - "field_5_w3\n" - " END SUBROUTINE testkern_three_scalars_code\n" - " END MODULE testkern_three_scalars_mod") + "SUBROUTINE testkern_three_scalars_code(nlayers, rscalar_1, " + "rfield_2_w1, rfield_3_w2, rfield_4_w2, rfield_5_w3, lscalar_6, " + "iscalar_7, ndf_w1, undf_w1, dofmap_w1, ndf_w2, undf_w2, dofmap_w2, " + "ndf_w3, undf_w3, dofmap_w3)\n").lower() + assert expected in result + expected2 = ( + " USE constants_mod, only : i_def, l_def, r_def\n" + " INTEGER(KIND=i_def), intent(in) :: nlayers\n" + " REAL(KIND=r_def), intent(in) :: rscalar_1\n" + " INTEGER(KIND=i_def), intent(in) :: undf_w1\n" + " REAL(KIND=r_def), dimension(undf_w1), intent(inout) :: " + "rfield_2_w1\n" + " INTEGER(KIND=i_def), intent(in) :: undf_w2\n" + " REAL(KIND=r_def), dimension(undf_w2), intent(in) :: rfield_3_w2\n" + " REAL(KIND=r_def), dimension(undf_w2), intent(in) :: rfield_4_w2\n" + " INTEGER(KIND=i_def), intent(in) :: undf_w3\n" + " REAL(KIND=r_def), dimension(undf_w3), intent(in) :: rfield_5_w3\n" + " LOGICAL(KIND=l_def), intent(in) :: lscalar_6\n" + " INTEGER(KIND=i_def), intent(in) :: iscalar_7\n" + " INTEGER(KIND=i_def), intent(in) :: ndf_w1\n" + " INTEGER(KIND=i_def), dimension(ndf_w1), intent(in) :: dofmap_w1\n" + " INTEGER(KIND=i_def), intent(in) :: ndf_w2\n" + " INTEGER(KIND=i_def), dimension(ndf_w2), intent(in) :: dofmap_w2\n" + " INTEGER(KIND=i_def), intent(in) :: ndf_w3\n" + " INTEGER(KIND=i_def), dimension(ndf_w3), intent(in) :: dofmap_w3\n" + ).lower() - assert expected in str(result) + assert expected2 in result def test_stub_generate_with_scalar_sums_err(): @@ -133,7 +128,6 @@ def test_stub_generate_with_scalar_sums_err(): _ = generate( os.path.join(BASE_PATH, "testkern_simple_with_reduction_mod.f90"), api=TEST_API) - assert ( - "A user-supplied LFRic kernel must not write/update a scalar " - "argument but kernel 'simple_with_reduction_type' has a scalar " - "argument with 'gh_sum' access." in str(err.value)) + assert ("Scalar arguments to general-purpose kernels with 'operates_on == " + "cell_column' must be read-only but found 'gh_real' scalar with " + "'gh_sum' access in" in str(err.value)) diff --git a/src/psyclone/tests/dynamo0p3_cma_test.py b/src/psyclone/tests/dynamo0p3_cma_test.py index ce2a3b3e41..c3cb672af4 100644 --- a/src/psyclone/tests/dynamo0p3_cma_test.py +++ b/src/psyclone/tests/dynamo0p3_cma_test.py @@ -44,6 +44,8 @@ from fparser import api as fpapi from psyclone.tests.lfric_build import LFRicBuild +from psyclone.tests.utilities import print_diffs + from psyclone.configuration import Config from psyclone.core.access_type import AccessType from psyclone.domain.lfric import (LFRicArgDescriptor, LFRicConstants, @@ -53,6 +55,7 @@ from psyclone.parse.algorithm import parse from psyclone.parse.utils import ParseError from psyclone.psyGen import PSyFactory +from psyclone.psyir import nodes, symbols # Constants BASE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), @@ -1306,91 +1309,106 @@ def test_cma_multi_kernel(tmpdir, dist_mem): # Tests for the kernel-stub generator -def test_cma_asm_stub_gen(): +def test_cma_asm_stub_gen(fortran_writer): ''' Test the kernel-stub generator for CMA operator assembly. ''' path = os.path.join(BASE_PATH, "columnwise_op_asm_kernel_mod.F90") - result = generate(path, api=TEST_API) - - expected = ( - " MODULE columnwise_op_asm_kernel_mod\n" - " IMPLICIT NONE\n" - " CONTAINS\n" - " SUBROUTINE columnwise_op_asm_kernel_code(cell, nlayers, " - "ncell_2d, op_1_ncell_3d, op_1, cma_op_2, cma_op_2_nrow, " - "cma_op_2_ncol, cma_op_2_bandwidth, cma_op_2_alpha, cma_op_2_beta, " - "cma_op_2_gamma_m, cma_op_2_gamma_p, ndf_adspc1_op_1, " - "cbanded_map_adspc1_op_1, ndf_adspc2_op_1, cbanded_map_adspc2_op_1)\n" - " USE constants_mod\n" - " IMPLICIT NONE\n" - " INTEGER(KIND=i_def), intent(in) :: nlayers\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_adspc1_op_1\n" - " INTEGER(KIND=i_def), intent(in), dimension(" - "ndf_adspc1_op_1,nlayers) :: cbanded_map_adspc1_op_1\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_adspc2_op_1\n" - " INTEGER(KIND=i_def), intent(in), dimension(" - "ndf_adspc2_op_1,nlayers) :: cbanded_map_adspc2_op_1\n" - " INTEGER(KIND=i_def), intent(in) :: cell, ncell_2d\n" - " INTEGER(KIND=i_def), intent(in) :: cma_op_2_nrow, " - "cma_op_2_ncol, cma_op_2_bandwidth, cma_op_2_alpha, cma_op_2_beta, " - "cma_op_2_gamma_m, cma_op_2_gamma_p\n" - " REAL(KIND=r_solver), intent(inout), dimension(" - "cma_op_2_bandwidth,cma_op_2_nrow,ncell_2d) :: cma_op_2\n" - " INTEGER(KIND=i_def), intent(in) :: op_1_ncell_3d\n" - " REAL(KIND=r_def), intent(in), dimension(op_1_ncell_3d," - "ndf_adspc1_op_1,ndf_adspc2_op_1) :: op_1\n" - " END SUBROUTINE columnwise_op_asm_kernel_code\n" - " END MODULE columnwise_op_asm_kernel_mod") - assert expected in str(result) - - -def test_cma_asm_with_field_stub_gen(): + psyir = generate(path, api=TEST_API) + result = fortran_writer(psyir) + expected1 = ( + "subroutine columnwise_op_asm_kernel_code(cell, nlayers, " + "ncell_2d, op_1_ncell_3d, op_1, cma_op_2, nrow_cma_op_2, " + "ncol_cma_op_2, bandwidth_cma_op_2, alpha_cma_op_2, beta_cma_op_2, " + "gamma_m_cma_op_2, gamma_p_cma_op_2, ndf_adspc1, " + "cbanded_map_adspc1_cma_op_2, ndf_adspc2, " + "cbanded_map_adspc2_cma_op_2)\n") + assert expected1 in result + expected2 = ( + " use constants_mod, only : i_def, r_def\n" + " integer(kind=i_def), intent(in) :: cell\n" + " integer(kind=i_def), intent(in) :: nlayers\n" + " integer(kind=i_def), intent(in) :: ncell_2d\n" + " integer(kind=i_def), intent(in) :: op_1_ncell_3d\n" + " integer(kind=i_def), intent(in) :: ndf_adspc1\n" + " integer(kind=i_def), intent(in) :: ndf_adspc2\n" + " real(kind=r_def), dimension(op_1_ncell_3d, ndf_adspc1," + "ndf_adspc2), intent(in) :: op_1\n" + " integer(kind=i_def), intent(in) :: bandwidth_cma_op_2\n" + " integer(kind=i_def), intent(in) :: nrow_cma_op_2\n" + # Should be r_solver TODO + " real(kind=r_def), dimension(bandwidth_cma_op_2," + "nrow_cma_op_2,ncell_2d), intent(out) :: cma_op_2\n" + " integer(kind=i_def), intent(in) :: ncol_cma_op_2\n" + " integer(kind=i_def), intent(in) :: alpha_cma_op_2\n" + " integer(kind=i_def), intent(in) :: beta_cma_op_2\n" + " integer(kind=i_def), intent(in) :: gamma_m_cma_op_2\n" + " integer(kind=i_def), intent(in) :: gamma_p_cma_op_2\n" + " integer(kind=i_def), dimension(ndf_adspc1,nlayers), intent(in) " + ":: cbanded_map_adspc1_cma_op_2\n" + " integer(kind=i_def), dimension(ndf_adspc2,nlayers), intent(in) " + ":: cbanded_map_adspc2_cma_op_2\n\n\n" + " end subroutine columnwise_op_asm_kernel_code\n") + if expected2 not in result: + print_diffs(expected2, result) + assert 0 + + +def test_cma_asm_with_field_stub_gen(fortran_writer): ''' Test the kernel-stub generator for CMA operator assembly when a field is involved. ''' - result = generate(os.path.join(BASE_PATH, - "columnwise_op_asm_field_kernel_mod.F90"), - api=TEST_API) - + psyir = generate(os.path.join(BASE_PATH, + "columnwise_op_asm_field_kernel_mod.F90"), + api=TEST_API) + routine = psyir.walk(nodes.Routine)[0] + assert routine.name == "columnwise_op_asm_field_kernel_code" + table = routine.symbol_table + assert [sym.name for sym in table.argument_list] == [ + "cell", "nlayers", "ncell_2d", "field_1_aspc1_field_1", + "op_2_ncell_3d", "op_2", "cma_op_3", "nrow_cma_op_3", "ncol_cma_op_3", + "bandwidth_cma_op_3", "alpha_cma_op_3", "beta_cma_op_3", + "gamma_m_cma_op_3", "gamma_p_cma_op_3", + "ndf_aspc1_field_1", "undf_aspc1_field_1", "map_aspc1_field_1", + "cbanded_map_aspc1_field_1", "ndf_aspc2_op_2", "cbanded_map_aspc2_op_2"] + cmod = table.lookup("constants_mod") + assert isinstance(cmd, symbols.ContainerSymbol) expected = ( - " MODULE columnwise_op_asm_field_kernel_mod\n" - " IMPLICIT NONE\n" - " CONTAINS\n" - " SUBROUTINE columnwise_op_asm_field_kernel_code(cell, nlayers, " - "ncell_2d, field_1_aspc1_field_1, op_2_ncell_3d, op_2, cma_op_3, " - "cma_op_3_nrow, cma_op_3_ncol, cma_op_3_bandwidth, cma_op_3_alpha, " - "cma_op_3_beta, cma_op_3_gamma_m, cma_op_3_gamma_p, " - "ndf_aspc1_field_1, undf_aspc1_field_1, map_aspc1_field_1, " - "cbanded_map_aspc1_field_1, ndf_aspc2_op_2, cbanded_map_aspc2_op_2)\n" - " USE constants_mod\n" - " IMPLICIT NONE\n" - " INTEGER(KIND=i_def), intent(in) :: nlayers\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_aspc1_field_1\n" - " INTEGER(KIND=i_def), intent(in), " - "dimension(ndf_aspc1_field_1) :: map_aspc1_field_1\n" - " INTEGER(KIND=i_def), intent(in), " - "dimension(ndf_aspc1_field_1,nlayers) :: cbanded_map_aspc1_field_1\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_aspc2_op_2\n" - " INTEGER(KIND=i_def), intent(in), " - "dimension(ndf_aspc2_op_2,nlayers) :: cbanded_map_aspc2_op_2\n" - " INTEGER(KIND=i_def), intent(in) :: undf_aspc1_field_1\n" - " INTEGER(KIND=i_def), intent(in) :: cell, ncell_2d\n" - " INTEGER(KIND=i_def), intent(in) :: cma_op_3_nrow, " - "cma_op_3_ncol, cma_op_3_bandwidth, cma_op_3_alpha, cma_op_3_beta, " - "cma_op_3_gamma_m, cma_op_3_gamma_p\n" - " REAL(KIND=r_solver), intent(inout), dimension(" - "cma_op_3_bandwidth,cma_op_3_nrow,ncell_2d) :: cma_op_3\n" - " REAL(KIND=r_def), intent(in), dimension(undf_aspc1_field_1) :: " - "field_1_aspc1_field_1\n" - " INTEGER(KIND=i_def), intent(in) :: op_2_ncell_3d\n" - " REAL(KIND=r_def), intent(in), dimension(" - "op_2_ncell_3d,ndf_aspc1_field_1,ndf_aspc2_op_2) :: op_2\n" - " END SUBROUTINE columnwise_op_asm_field_kernel_code\n" - " END MODULE columnwise_op_asm_field_kernel_mod") - assert expected in str(result) + " use constants_mod, only : i_def, r_def\n" + " integer(kind=i_def), intent(in) :: cell\n" + " integer(kind=i_def), intent(in) :: nlayers\n" + " integer(kind=i_def), intent(in) :: ncell_2d\n" + " integer(kind=i_def), intent(in) :: undf_aspc1\n" + " real(kind=r_def), dimension(undf_aspc1), intent(in) :: " + "rfield_1_aspc1\n" + " integer(kind=i_def), intent(in) :: op_2_ncell_3d\n" + " integer(kind=i_def), intent(in) :: ndf_aspc1\n" + " integer(kind=i_def), intent(in) :: ndf_aspc2\n" + " real(kind=r_def), dimension(op_2_ncell_3d,ndf_aspc1,ndf_aspc2), " + "intent(in) :: op_2\n" + " integer(kind=i_def), intent(in) :: bandwidth_cma_op_3\n" + " integer(kind=i_def), intent(in) :: nrow_cma_op_3\n" + # TODO was r_solver and (inout) before? + " real(kind=r_def), dimension(" + "bandwidth_cma_op_3,nrow_cma_op_3,ncell_2d), intent(out) :: cma_op_3\n" + " integer(kind=i_def), intent(in) :: ncol_cma_op_3\n" + " integer(kind=i_def), intent(in) :: alpha_cma_op_3\n" + " integer(kind=i_def), intent(in) :: beta_cma_op_3\n" + " integer(kind=i_def), intent(in) :: gamma_m_cma_op_3\n" + " integer(kind=i_def), intent(in) :: gamma_p_cma_op_3\n" + " integer(kind=i_def), dimension(ndf_aspc1), intent(in) :: " + "dofmap_aspc1\n" + " integer(kind=i_def), dimension(ndf_aspc1,nlayers), intent(in) :: " + "cbanded_map_aspc1_cma_op_3\n" + " integer(kind=i_def), dimension(ndf_aspc2,nlayers), intent(in) :: " + "cbanded_map_aspc2_cma_op_3\n" + " end subroutine columnwise_op_asm_field_kernel_code\n" + ) + if expected not in result: + print_diffs(expected, result) + assert 0 def test_cma_asm_same_fs_stub_gen(): diff --git a/src/psyclone/tests/dynamo0p3_stubgen_test.py b/src/psyclone/tests/dynamo0p3_stubgen_test.py index aca20fdda1..32147a08a5 100644 --- a/src/psyclone/tests/dynamo0p3_stubgen_test.py +++ b/src/psyclone/tests/dynamo0p3_stubgen_test.py @@ -85,44 +85,44 @@ def test_kernel_stub_invalid_iteration_space(): "'testkern_dofs_code'." in str(excinfo.value)) -def test_stub_generate_with_anyw2(): +def test_stub_generate_with_anyw2(fortran_writer): '''check that the stub generate produces the expected output when we have any_w2 fields. In particular, check basis functions as these have specific sizes associated with the particular function space''' - result = generate(os.path.join(BASE_PATH, - "testkern_multi_anyw2_basis_mod.f90"), - api=TEST_API) - expected_output = ( - " REAL(KIND=r_def), intent(in), dimension(3,ndf_any_w2," - "np_xy_qr_xyoz,np_z_qr_xyoz) :: basis_any_w2_qr_xyoz\n" - " REAL(KIND=r_def), intent(in), dimension(1,ndf_any_w2," - "np_xy_qr_xyoz,np_z_qr_xyoz) :: diff_basis_any_w2_qr_xyoz") - assert expected_output in str(result) - - -SIMPLE = ( - " MODULE simple_mod\n" - " IMPLICIT NONE\n" - " CONTAINS\n" - " SUBROUTINE simple_code(nlayers, field_1_w1, ndf_w1, undf_w1," - " map_w1)\n" - " USE constants_mod\n" - " IMPLICIT NONE\n" - " INTEGER(KIND=i_def), intent(in) :: nlayers\n" - " INTEGER(KIND=i_def), intent(in) :: ndf_w1\n" - " INTEGER(KIND=i_def), intent(in), dimension(ndf_w1) :: map_w1\n" - " INTEGER(KIND=i_def), intent(in) :: undf_w1\n" - " REAL(KIND=r_def), intent(inout), dimension(undf_w1) :: " - "field_1_w1\n" - " END SUBROUTINE simple_code\n" - " END MODULE simple_mod") - - -def test_stub_generate_working(): + psyir = generate(os.path.join(BASE_PATH, + "testkern_multi_anyw2_basis_mod.f90"), + api=TEST_API) + result = fortran_writer(psyir) + + assert ("real(kind=r_def), dimension(3_i_def,ndf_any_w2,np_xy_qr_xyoz," + "np_z_qr_xyoz), intent(in) :: basis_any_w2_qr_xyoz\n" in result) + assert ("real(kind=r_def), dimension(1_i_def,ndf_any_w2,np_xy_qr_xyoz," + "np_z_qr_xyoz), intent(in) :: diff_basis_any_w2_qr_xyoz" in result) + + +SIMPLE = [ + "module simple_mod", + "implicit none", + "contains", + "subroutine simple_code(nlayers, rfield_1_w1, ndf_w1, undf_w1, dofmap_w1)", + "use constants_mod", + "implicit none", + "integer(kind=i_def), intent(in) :: nlayers", + "integer(kind=i_def), intent(in) :: ndf_w1", + "integer(kind=i_def), dimension(ndf_w1), intent(in) :: dofmap_w1", + "integer(kind=i_def), intent(in) :: undf_w1", + "real(kind=r_def), dimension(undf_w1), intent(inout) :: rfield_1_w1", + "end subroutine simple_code", + "end module simple_mod"] + + +def test_stub_generate_working(fortran_writer): ''' Check that the stub generate produces the expected output ''' result = generate(os.path.join(BASE_PATH, "testkern_simple_mod.f90"), api=TEST_API) - assert SIMPLE in str(result) + out = fortran_writer(result) + for line in SIMPLE: + assert line in out # Fields : intent diff --git a/src/psyclone/tests/gen_kernel_stub_test.py b/src/psyclone/tests/gen_kernel_stub_test.py index ea0801d77b..1c1719f030 100644 --- a/src/psyclone/tests/gen_kernel_stub_test.py +++ b/src/psyclone/tests/gen_kernel_stub_test.py @@ -39,8 +39,7 @@ import os import pytest -import fparser - +from psyclone.psyir import nodes from psyclone.errors import GenerationError from psyclone.gen_kernel_stub import generate from psyclone.parse.algorithm import ParseError @@ -72,6 +71,6 @@ def test_gen_success(): ''' Test for successful completion of the generate() function. ''' base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files", "dynamo0p3") - ptree = generate(os.path.join(base_path, "testkern_mod.F90"), + psyir = generate(os.path.join(base_path, "testkern_mod.F90"), api="lfric") - assert isinstance(ptree, fparser.one.block_statements.Module) + assert isinstance(psyir, nodes.Container) diff --git a/src/psyclone/tests/kernel_tools_test.py b/src/psyclone/tests/kernel_tools_test.py index f135640db5..3eb525a196 100644 --- a/src/psyclone/tests/kernel_tools_test.py +++ b/src/psyclone/tests/kernel_tools_test.py @@ -65,7 +65,8 @@ def test_run_default_mode(capsys): "test_files", "dynamo0p3", "testkern_w0_mod.f90") kernel_tools.run([str(kern_file), "-api", "lfric"]) out, err = capsys.readouterr() - assert "Kernel-stub code:\n MODULE testkern_w0_mod\n" in out + assert "Kernel-stub code:\n" in out + assert "module testkern_w0_mod\n" in out assert not err @@ -79,7 +80,7 @@ def test_run(capsys, tmpdir): "-gen", "stub"]) result, _ = capsys.readouterr() assert "Kernel-stub code:" in result - assert "MODULE testkern_w0_mod" in result + assert "module testkern_w0_mod" in result # Test without --limit, but with -o: psy_file = tmpdir.join("psy.f90") @@ -90,7 +91,7 @@ def test_run(capsys, tmpdir): # Now read output file into a string and check: with psy_file.open("r") as psy: output = psy.read() - assert "MODULE testkern_w0_mod" in str(output) + assert "module testkern_w0_mod" in str(output) def test_run_version(capsys): @@ -172,7 +173,7 @@ def test_run_line_length(fortran_reader, monkeypatch, capsys, limit, mode): ''' Check that line-length limiting is applied to generated algorithm and kernel-stub code when requested. ''' - def long_psyir_gen(_1, _2, _3): + def long_psyir_gen(_1, _2=None, api=""): ''' Function that returns PSyIR containing a line longer than 132 chars. ''' routine = Routine.create("my_sub", SymbolTable(), []) @@ -184,16 +185,10 @@ def long_psyir_gen(_1, _2, _3): container = Container.create("my_mod", SymbolTable(), [routine]) return container - def long_gen(_1, api=None): - ''' generate() function that returns a string longer than - 132 chars. ''' - # pylint: disable=unused-argument - return f"long_str = '{140*' '}'" - # Monkeypatch both the algorithm and stub creation functions. monkeypatch.setattr(algorithm.lfric_alg.LFRicAlg, "create_from_kernel", long_psyir_gen) - monkeypatch.setattr(gen_kernel_stub, "generate", long_gen) + monkeypatch.setattr(gen_kernel_stub, "generate", long_psyir_gen) args = ["-gen", mode, str("/does_not_exist"), "-api", "lfric"] if limit: args.extend(["--limit", "output"]) @@ -212,7 +207,7 @@ def test_file_output(fortran_reader, monkeypatch, mode): ''' Check that the output of the generate() function is written to file if requested. We test for both the kernel-stub & algorithm generation. ''' - def fake_psyir_gen(_1, _2, _3): + def fake_psyir_gen(_1, _2=None, api=None): '''Returns PSyIR for a module containing a particular string for testing purposes.''' return fortran_reader.psyir_from_source( @@ -226,15 +221,10 @@ def fake_psyir_gen(_1, _2, _3): end module ''') - def fake_gen(_1, api=None): - ''' generate() function that simply returns a string. ''' - # pylint: disable=unused-argument - return "the_answer = 42" - # Monkeypatch both the algorithm and stub creation functions. monkeypatch.setattr(algorithm.lfric_alg.LFRicAlg, "create_from_kernel", fake_psyir_gen) - monkeypatch.setattr(gen_kernel_stub, "generate", fake_gen) + monkeypatch.setattr(gen_kernel_stub, "generate", fake_psyir_gen) kernel_tools.run(["-api", "lfric", "-gen", mode, "-o", f"output_file_{mode}", str("/does_not_exist")])