Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle reserved names during code-printing #2483

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class AmiciCxxCodePrinter(CXX11CodePrinter):
"""

optimizations: Iterable[Optimization] = ()
RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"]

def __init__(self):
"""Create code printer"""
Expand Down Expand Up @@ -67,6 +68,12 @@ def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str:
f'Encountered unsupported function in expression "{expr}"'
) from e

def _print_Symbol(self, expr):
name = super()._print_Symbol(expr)
if name in self.RESERVED_SYMBOLS:
return f"amici_{name}"
return name

def _print_min_max(self, expr, cpp_fun: str, sympy_fun):
# C++ doesn't like mixing int and double for arguments for min/max,
# therefore, we just always convert to float
Expand Down
8 changes: 7 additions & 1 deletion python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ def _write_index_files(self, name: str) -> None:
lines = []
for index, symbol in enumerate(symbols):
symbol_name = strip_pysb(symbol)
# symbol_name is a mix of symbols and strings
symbol_name = self._code_printer._print_Symbol(
sp.Symbol(symbol_name)
if isinstance(symbol_name, str)
else symbol_name
)
if str(symbol) == "0":
continue
if str(symbol_name) == "":
Expand Down Expand Up @@ -1221,7 +1227,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
f'"{symbol}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
8 changes: 0 additions & 8 deletions python/sdist/amici/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sympy as sp

from .import_utils import (
RESERVED_SYMBOLS,
ObservableTransformation,
amici_time_symbol,
cast_to_sym,
Expand Down Expand Up @@ -66,13 +65,6 @@ def __init__(
f"identifier must be sympy.Symbol, was " f"{type(identifier)}"
)

if str(identifier) in RESERVED_SYMBOLS or (
hasattr(identifier, "name") and identifier.name in RESERVED_SYMBOLS
):
raise ValueError(
f'Cannot add model quantity with name "{name}", '
f"please rename."
)
self._identifier: sp.Symbol = identifier

if not isinstance(name, str):
Expand Down
2 changes: 0 additions & 2 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from sympy.logic.boolalg import BooleanAtom
from toposort import toposort

RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"]

try:
import pysb
except ImportError:
Expand Down
20 changes: 0 additions & 20 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from .de_model_components import symbol_to_type, Expression
from .sympy_utils import smart_is_zero_matrix, smart_multiply
from .import_utils import (
RESERVED_SYMBOLS,
_check_unsupported_functions,
_get_str_symbol_identifiers,
_parse_special_functions,
Expand Down Expand Up @@ -523,7 +522,6 @@ def _build_ode_model(
)
self._replace_compartments_with_volumes()

self._clean_reserved_symbols()
self._process_time()

ode_model = DEModel(
Expand Down Expand Up @@ -2596,24 +2594,6 @@ def _replace_in_all_expressions(
for spline in self.splines:
spline._replace_in_all_expressions(old, new)

def _clean_reserved_symbols(self) -> None:
"""
Remove all reserved symbols from self.symbols
"""
for sym in RESERVED_SYMBOLS:
old_symbol = symbol_with_assumptions(sym)
new_symbol = symbol_with_assumptions(f"amici_{sym}")
self._replace_in_all_expressions(
old_symbol, new_symbol, replace_identifiers=True
)
for symbols_ids, symbols in self.symbols.items():
if old_symbol in symbols:
# reconstitute the whole dict in order to keep the ordering
self.symbols[symbols_ids] = {
new_symbol if k is old_symbol else k: v
for k, v in symbols.items()
}

def _sympy_from_sbml_math(
self, var_or_math: [sbml.SBase, str]
) -> sp.Expr | float | None:
Expand Down
36 changes: 36 additions & 0 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,39 @@ def test_constraints():
amici_solver.getAbsoluteTolerance(),
)
)


def test_reserved_symbols():
"""Test handling of reserved one-letter names."""
from amici.antimony_import import antimony2amici

ant_model = """
model test_non_negative_species
t = 0.1
x = 0.2
y = 0.3
w = 0.4
h = 0.5
p = 0.6
k = 0.7
x' = k + x + p + y + w + h + t
end
"""
module_name = "test_reserved_symbols"
with TemporaryDirectory(prefix=module_name) as outdir:
antimony2amici(
ant_model,
model_name=module_name,
output_dir=outdir,
compute_conservation_laws=False,
)
# ensure it compiled successfully and can be imported
model_module = amici.import_model_module(
module_name=module_name, module_path=outdir
)
model = model_module.get_model()
ids = list(model.getParameterIds())
ids.extend(model.getStateIds())
# all symbols should be present with their original IDs
for symbol in "txywhpk":
assert symbol in ids
Loading