From 85e35ae5082da3fc96a4e088d18adefae3753135 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 28 Feb 2024 10:25:19 +0100 Subject: [PATCH] Refactor de_export.py, extract _codegen.model_class (#2319) Move functionality for generating the ``amici::Model`` subclass code to a separate file. Related to #2306 --- python/sdist/amici/_codegen/model_class.py | 191 ++++++++++++++++++++ python/sdist/amici/de_export.py | 195 ++------------------- 2 files changed, 201 insertions(+), 185 deletions(-) create mode 100644 python/sdist/amici/_codegen/model_class.py diff --git a/python/sdist/amici/_codegen/model_class.py b/python/sdist/amici/_codegen/model_class.py new file mode 100644 index 0000000000..e6366c1dfd --- /dev/null +++ b/python/sdist/amici/_codegen/model_class.py @@ -0,0 +1,191 @@ +"""Function for generating the ``amici::Model`` subclass for an amici model.""" +from __future__ import annotations + +from .cxx_functions import functions, multiobs_functions + +from ..de_model_components import Event + + +def get_function_extern_declaration(fun: str, name: str, ode: bool) -> str: + """ + Constructs the extern function declaration for a given function + + :param fun: + function name + :param name: + model name + :param ode: + whether to generate declaration for DAE or ODE + + :return: + C++ function definition string + """ + f = functions[fun] + return f"extern {f.return_type} {fun}_{name}({f.arguments(ode)});" + + +def get_sunindex_extern_declaration( + fun: str, name: str, indextype: str +) -> str: + """ + Constructs the function declaration for an index function of a given + function + + :param fun: + function name + + :param name: + model name + + :param indextype: + index function {'colptrs', 'rowvals'} + + :return: + C++ function declaration string + """ + index_arg = ", int index" if fun in multiobs_functions else "" + return ( + f"extern void {fun}_{indextype}_{name}" + f"(SUNMatrixWrapper &{indextype}{index_arg});" + ) + + +def get_model_override_implementation( + fun: str, name: str, ode: bool, nobody: bool = False +) -> str: + """ + Constructs ``amici::Model::*`` override implementation for a given function + + :param fun: + function name + + :param name: + model name + + :param nobody: + whether the function has a nontrivial implementation + + :return: + C++ function implementation string + """ + func_info = functions[fun] + body = ( + "" + if nobody + else "\n{ind8}{maybe_return}{fun}_{name}({eval_signature});\n{ind4}".format( + ind4=" " * 4, + ind8=" " * 8, + maybe_return="" if func_info.return_type == "void" else "return ", + fun=fun, + name=name, + eval_signature=remove_argument_types(func_info.arguments(ode)), + ) + ) + return "{return_type} f{fun}({signature}) override {{{body}}}\n".format( + return_type=func_info.return_type, + fun=fun, + signature=func_info.arguments(ode), + body=body, + ) + + +def get_sunindex_override_implementation( + fun: str, name: str, indextype: str, nobody: bool = False +) -> str: + """ + Constructs the ``amici::Model`` function implementation for an index + function of a given function + + :param fun: + function name + + :param name: + model name + + :param indextype: + index function {'colptrs', 'rowvals'} + + :param nobody: + whether the corresponding function has a nontrivial implementation + + :return: + C++ function implementation string + """ + index_arg = ", int index" if fun in multiobs_functions else "" + index_arg_eval = ", index" if fun in multiobs_functions else "" + + impl = "void f{fun}_{indextype}({signature}) override {{" + + if nobody: + impl += "}}\n" + else: + impl += ( + "\n{ind8}{fun}_{indextype}_{name}({eval_signature});\n{ind4}}}\n" + ) + + return impl.format( + ind4=" " * 4, + ind8=" " * 8, + fun=fun, + indextype=indextype, + name=name, + signature=f"SUNMatrixWrapper &{indextype}{index_arg}", + eval_signature=f"{indextype}{index_arg_eval}", + ) + + +def remove_argument_types(signature: str) -> str: + """ + Strips argument types from a function signature + + :param signature: + function signature + + :return: + string that can be used to construct function calls with the same + variable names and ordering as in the function signature + """ + # remove * prefix for pointers (pointer must always be removed before + # values otherwise we will inadvertently dereference values, + # same applies for const specifications) + # + # always add whitespace after type definition for cosmetic reasons + known_types = [ + "const realtype *", + "const double *", + "const realtype ", + "double *", + "realtype *", + "const int ", + "int ", + "bool ", + "SUNMatrixContent_Sparse ", + "gsl::span", + ] + + for type_str in known_types: + signature = signature.replace(type_str, "") + + return signature + + +def get_state_independent_event_intializer(events: list[Event]) -> str: + """Get initializer list for state independent events in amici::Model.""" + map_time_to_event_idx = {} + for event_idx, event in enumerate(events): + if not event.triggers_at_fixed_timepoint(): + continue + trigger_time = float(event.get_trigger_time()) + try: + map_time_to_event_idx[trigger_time].append(event_idx) + except KeyError: + map_time_to_event_idx[trigger_time] = [event_idx] + + def vector_initializer(v): + """std::vector initializer list with elements from `v`""" + return f"{{{', '.join(map(str, v))}}}" + + return ", ".join( + f"{{{trigger_time}, {vector_initializer(event_idxs)}}}" + for trigger_time, event_idxs in map_time_to_event_idx.items() + ) diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index d956056a2e..cdf5d6ac6a 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -49,6 +49,13 @@ event_sensi_functions, multiobs_functions, ) +from ._codegen.model_class import ( + get_function_extern_declaration, + get_sunindex_extern_declaration, + get_model_override_implementation, + get_sunindex_override_implementation, + get_state_independent_event_intializer, +) from ._codegen.template import apply_template from .cxxcodeprinter import ( AmiciCxxCodePrinter, @@ -3330,7 +3337,9 @@ def _write_model_header_cpp(self) -> None: ) ), "Z2EVENT": ", ".join(map(str, self.model._z2event)), - "STATE_INDEPENDENT_EVENTS": self._get_state_independent_event_intializer(), + "STATE_INDEPENDENT_EVENTS": get_state_independent_event_intializer( + self.model.events() + ), "ID": ", ".join( str(float(isinstance(s, DifferentialState))) for s in self.model.states() @@ -3464,27 +3473,6 @@ def _get_symbol_id_initializer_list(self, name: str) -> str: for idx, symbol in enumerate(self.model.sym(name)) ) - def _get_state_independent_event_intializer(self) -> str: - """Get initializer list for state independent events in amici::Model.""" - map_time_to_event_idx = {} - for event_idx, event in enumerate(self.model.events()): - if not event.triggers_at_fixed_timepoint(): - continue - trigger_time = float(event.get_trigger_time()) - try: - map_time_to_event_idx[trigger_time].append(event_idx) - except KeyError: - map_time_to_event_idx[trigger_time] = [event_idx] - - def vector_initializer(v): - """std::vector initializer list with elements from `v`""" - return f"{{{', '.join(map(str, v))}}}" - - return ", ".join( - f"{{{trigger_time}, {vector_initializer(event_idxs)}}}" - for trigger_time, event_idxs in map_time_to_event_idx.items() - ) - def _write_c_make_file(self): """Write CMake ``CMakeLists.txt`` file for this model.""" sources = "\n".join( @@ -3587,169 +3575,6 @@ def set_name(self, model_name: str) -> None: self.model_name = model_name -def get_function_extern_declaration(fun: str, name: str, ode: bool) -> str: - """ - Constructs the extern function declaration for a given function - - :param fun: - function name - :param name: - model name - :param ode: - whether to generate declaration for DAE or ODE - - :return: - C++ function definition string - """ - f = functions[fun] - return f"extern {f.return_type} {fun}_{name}({f.arguments(ode)});" - - -def get_sunindex_extern_declaration( - fun: str, name: str, indextype: str -) -> str: - """ - Constructs the function declaration for an index function of a given - function - - :param fun: - function name - - :param name: - model name - - :param indextype: - index function {'colptrs', 'rowvals'} - - :return: - C++ function declaration string - """ - index_arg = ", int index" if fun in multiobs_functions else "" - return ( - f"extern void {fun}_{indextype}_{name}" - f"(SUNMatrixWrapper &{indextype}{index_arg});" - ) - - -def get_model_override_implementation( - fun: str, name: str, ode: bool, nobody: bool = False -) -> str: - """ - Constructs ``amici::Model::*`` override implementation for a given function - - :param fun: - function name - - :param name: - model name - - :param nobody: - whether the function has a nontrivial implementation - - :return: - C++ function implementation string - """ - func_info = functions[fun] - body = ( - "" - if nobody - else "\n{ind8}{maybe_return}{fun}_{name}({eval_signature});\n{ind4}".format( - ind4=" " * 4, - ind8=" " * 8, - maybe_return="" if func_info.return_type == "void" else "return ", - fun=fun, - name=name, - eval_signature=remove_argument_types(func_info.arguments(ode)), - ) - ) - return "{return_type} f{fun}({signature}) override {{{body}}}\n".format( - return_type=func_info.return_type, - fun=fun, - signature=func_info.arguments(ode), - body=body, - ) - - -def get_sunindex_override_implementation( - fun: str, name: str, indextype: str, nobody: bool = False -) -> str: - """ - Constructs the ``amici::Model`` function implementation for an index - function of a given function - - :param fun: - function name - - :param name: - model name - - :param indextype: - index function {'colptrs', 'rowvals'} - - :param nobody: - whether the corresponding function has a nontrivial implementation - - :return: - C++ function implementation string - """ - index_arg = ", int index" if fun in multiobs_functions else "" - index_arg_eval = ", index" if fun in multiobs_functions else "" - - impl = "void f{fun}_{indextype}({signature}) override {{" - - if nobody: - impl += "}}\n" - else: - impl += ( - "\n{ind8}{fun}_{indextype}_{name}({eval_signature});\n{ind4}}}\n" - ) - - return impl.format( - ind4=" " * 4, - ind8=" " * 8, - fun=fun, - indextype=indextype, - name=name, - signature=f"SUNMatrixWrapper &{indextype}{index_arg}", - eval_signature=f"{indextype}{index_arg_eval}", - ) - - -def remove_argument_types(signature: str) -> str: - """ - Strips argument types from a function signature - - :param signature: - function signature - - :return: - string that can be used to construct function calls with the same - variable names and ordering as in the function signature - """ - # remove * prefix for pointers (pointer must always be removed before - # values otherwise we will inadvertently dereference values, - # same applies for const specifications) - # - # always add whitespace after type definition for cosmetic reasons - known_types = [ - "const realtype *", - "const double *", - "const realtype ", - "double *", - "realtype *", - "const int ", - "int ", - "bool ", - "SUNMatrixContent_Sparse ", - "gsl::span", - ] - - for type_str in known_types: - signature = signature.replace(type_str, "") - - return signature - - def is_valid_identifier(x: str) -> bool: """ Check whether `x` is a valid identifier for conditions, parameters,