diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py index 69060eb00a..121d79d730 100644 --- a/python/sdist/amici/cxxcodeprinter.py +++ b/python/sdist/amici/cxxcodeprinter.py @@ -30,6 +30,7 @@ class AmiciCxxCodePrinter(CXX11CodePrinter): """ optimizations: Iterable[Optimization] = () + RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"] def __init__(self): """Create code printer""" @@ -67,6 +68,12 @@ def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str: f'Encountered unsupported function in expression "{expr}"' ) from e + def _print_Symbol(self, expr): + name = super()._print_Symbol(expr) + if name in self.RESERVED_SYMBOLS: + return f"amici_{name}" + return name + def _print_min_max(self, expr, cpp_fun: str, sympy_fun): # C++ doesn't like mixing int and double for arguments for min/max, # therefore, we just always convert to float diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 6b1392a3d1..752095c44f 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -394,6 +394,12 @@ def _write_index_files(self, name: str) -> None: lines = [] for index, symbol in enumerate(symbols): symbol_name = strip_pysb(symbol) + # symbol_name is a mix of symbols and strings + symbol_name = self._code_printer._print_Symbol( + sp.Symbol(symbol_name) + if isinstance(symbol_name, str) + else symbol_name + ) if str(symbol) == "0": continue if str(symbol_name) == "": @@ -1221,7 +1227,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str: Template initializer list of ids """ return "\n".join( - f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]' + f'"{symbol}", // {name}[{idx}]' for idx, symbol in enumerate(self.model.sym(name)) ) diff --git a/python/sdist/amici/de_model_components.py b/python/sdist/amici/de_model_components.py index bc93f44b87..507a09cc1d 100644 --- a/python/sdist/amici/de_model_components.py +++ b/python/sdist/amici/de_model_components.py @@ -7,7 +7,6 @@ import sympy as sp from .import_utils import ( - RESERVED_SYMBOLS, ObservableTransformation, amici_time_symbol, cast_to_sym, @@ -66,13 +65,6 @@ def __init__( f"identifier must be sympy.Symbol, was " f"{type(identifier)}" ) - if str(identifier) in RESERVED_SYMBOLS or ( - hasattr(identifier, "name") and identifier.name in RESERVED_SYMBOLS - ): - raise ValueError( - f'Cannot add model quantity with name "{name}", ' - f"please rename." - ) self._identifier: sp.Symbol = identifier if not isinstance(name, str): diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 1a0dc782db..cf4e547fec 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -18,8 +18,6 @@ from sympy.logic.boolalg import BooleanAtom from toposort import toposort -RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"] - try: import pysb except ImportError: diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 61ce9a0ee1..0f3de093e9 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -35,7 +35,6 @@ from .de_model_components import symbol_to_type, Expression from .sympy_utils import smart_is_zero_matrix, smart_multiply from .import_utils import ( - RESERVED_SYMBOLS, _check_unsupported_functions, _get_str_symbol_identifiers, _parse_special_functions, @@ -523,7 +522,6 @@ def _build_ode_model( ) self._replace_compartments_with_volumes() - self._clean_reserved_symbols() self._process_time() ode_model = DEModel( @@ -2596,24 +2594,6 @@ def _replace_in_all_expressions( for spline in self.splines: spline._replace_in_all_expressions(old, new) - def _clean_reserved_symbols(self) -> None: - """ - Remove all reserved symbols from self.symbols - """ - for sym in RESERVED_SYMBOLS: - old_symbol = symbol_with_assumptions(sym) - new_symbol = symbol_with_assumptions(f"amici_{sym}") - self._replace_in_all_expressions( - old_symbol, new_symbol, replace_identifiers=True - ) - for symbols_ids, symbols in self.symbols.items(): - if old_symbol in symbols: - # reconstitute the whole dict in order to keep the ordering - self.symbols[symbols_ids] = { - new_symbol if k is old_symbol else k: v - for k, v in symbols.items() - } - def _sympy_from_sbml_math( self, var_or_math: [sbml.SBase, str] ) -> sp.Expr | float | None: diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 4936a3c901..aa5a63c534 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -773,3 +773,39 @@ def test_constraints(): amici_solver.getAbsoluteTolerance(), ) ) + + +def test_reserved_symbols(): + """Test handling of reserved one-letter names.""" + from amici.antimony_import import antimony2amici + + ant_model = """ + model test_non_negative_species + t = 0.1 + x = 0.2 + y = 0.3 + w = 0.4 + h = 0.5 + p = 0.6 + k = 0.7 + x' = k + x + p + y + w + h + t + end + """ + module_name = "test_reserved_symbols" + with TemporaryDirectory(prefix=module_name) as outdir: + antimony2amici( + ant_model, + model_name=module_name, + output_dir=outdir, + compute_conservation_laws=False, + ) + # ensure it compiled successfully and can be imported + model_module = amici.import_model_module( + module_name=module_name, module_path=outdir + ) + model = model_module.get_model() + ids = list(model.getParameterIds()) + ids.extend(model.getStateIds()) + # all symbols should be present with their original IDs + for symbol in "txywhpk": + assert symbol in ids