diff --git a/python/sdist/amici/petab/pysb_import.py b/python/sdist/amici/petab/pysb_import.py index 32de3d6666..97861674a5 100644 --- a/python/sdist/amici/petab/pysb_import.py +++ b/python/sdist/amici/petab/pysb_import.py @@ -28,7 +28,7 @@ def _add_observation_model( - pysb_model: pysb.Model, petab_problem: petab.Problem + pysb_model: pysb.Model, petab_problem: petab.Problem, jax: bool = False ): """Extend PySB model by observation model as defined in the PEtab observables table""" @@ -39,22 +39,40 @@ def _add_observation_model( for comp in pysb_model.components if isinstance(comp, sp.Symbol) } - for formula in [ - *petab_problem.observable_df[OBSERVABLE_FORMULA], - *petab_problem.observable_df[NOISE_FORMULA], - ]: - sym = sp.sympify(formula, locals=local_syms) - for s in sym.free_symbols: - if not isinstance(s, pysb.Component): - p = pysb.Parameter(str(s), 1.0) - pysb_model.add_component(p) - local_syms[sp.Symbol.__str__(p)] = p + obs_df = petab_problem.observable_df.copy() + for col, placeholder_pattern in ( + (OBSERVABLE_FORMULA, r"^(observableParameter\d+)_\w+$"), + (NOISE_FORMULA, r"^(noiseParameter\d+)_\w+$"), + ): + for ir, formula in petab_problem.observable_df[col].items(): + sym = sp.sympify(formula, locals=local_syms) + for s in sym.free_symbols: + if not isinstance(s, pysb.Component): + if jax: + name = re.sub(placeholder_pattern, r"\1", str(s)) + else: + name = str(s) + p = pysb.Parameter(name, 1.0) + pysb_model.add_component(p) + + # placeholders for multiple observables are mapped to the same symbol, so only add to local_syms + # when necessary + if name not in local_syms: + local_syms[name] = p + + # replace placeholder with parameter + if jax and name != str(s): + sym = sym.subs(s, local_syms[name]) + + # update forum + if jax: + obs_df.at[ir, col] = sym # add observables and sigmas to pysb model for observable_id, observable_formula, noise_formula in zip( - petab_problem.observable_df.index, - petab_problem.observable_df[OBSERVABLE_FORMULA], - petab_problem.observable_df[NOISE_FORMULA], + obs_df.index, + obs_df[OBSERVABLE_FORMULA], + obs_df[NOISE_FORMULA], strict=True, ): obs_symbol = sp.sympify(observable_formula, locals=local_syms) @@ -210,7 +228,7 @@ def import_model_pysb( name=petab_problem.model.model_id, ) - _add_observation_model(pysb_model, petab_problem) + _add_observation_model(pysb_model, petab_problem, jax) # generate species for the _original_ model pysb.bng.generate_equations(petab_problem.model.model) fixed_parameters = _add_initialization_variables(pysb_model, petab_problem) diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index b84fadea44..cb1c7f17f6 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -8,6 +8,7 @@ import itertools import logging import os +import re import sys from pathlib import Path from typing import ( @@ -33,6 +34,7 @@ SigmaY, ) from .de_model import DEModel +from .de_model_components import NoiseParameter, ObservableParameter from .import_utils import ( _get_str_symbol_identifiers, _parse_special_functions, @@ -137,6 +139,7 @@ def pysb2jax( simplify=simplify, cache_simplify=cache_simplify, verbose=verbose, + jax=True, ) from amici.jax.ode_export import ODEExporter @@ -300,6 +303,7 @@ def ode_model_from_pysb_importer( # See https://github.com/AMICI-dev/AMICI/pull/1672 cache_simplify: bool = False, verbose: int | bool = False, + jax: bool = False, ) -> DEModel: """ Creates an :class:`amici.DEModel` instance from a :class:`pysb.Model` @@ -335,6 +339,9 @@ def ode_model_from_pysb_importer( :param verbose: verbosity level for logging, True/False default to :attr:`logging.DEBUG`/:attr:`logging.ERROR` + :param jax: + if set to ``True``, the generated model will be compatible with JAX export + :return: New DEModel instance according to pysbModel """ @@ -357,7 +364,7 @@ def ode_model_from_pysb_importer( pysb.bng.generate_equations(model, verbose=verbose) _process_pysb_species(model, ode) - _process_pysb_parameters(model, ode, constant_parameters) + _process_pysb_parameters(model, ode, constant_parameters, jax) if compute_conservation_laws: _process_pysb_conservation_laws(model, ode) _process_pysb_observables( @@ -510,7 +517,10 @@ def _process_pysb_species(pysb_model: pysb.Model, ode_model: DEModel) -> None: @log_execution_time("processing PySB parameters", logger) def _process_pysb_parameters( - pysb_model: pysb.Model, ode_model: DEModel, constant_parameters: list[str] + pysb_model: pysb.Model, + ode_model: DEModel, + constant_parameters: list[str], + jax: bool = False, ) -> None: """ Converts pysb parameters into Parameters or Constants and adds them to @@ -522,16 +532,26 @@ def _process_pysb_parameters( :param constant_parameters: list of Parameters that should be constants + :param jax: + if set to ``True``, the generated model will be compatible JAX export + :param ode_model: DEModel instance """ for par in pysb_model.parameters: + args = [par, f"{par.name}"] if par.name in constant_parameters: comp = Constant + args.append(par.value) + elif jax and re.match(r"noiseParameter\d+", par.name): + comp = NoiseParameter + elif jax and re.match(r"observableParameter\d+", par.name): + comp = ObservableParameter else: comp = Parameter + args.append(par.value) - ode_model.add_component(comp(par, f"{par.name}", par.value)) + ode_model.add_component(comp(*args)) @log_execution_time("processing PySB expressions", logger) @@ -635,11 +655,11 @@ def _add_expression( :param ode_model: see :py:func:`_process_pysb_expressions` """ - ode_model.add_component( - Expression(sym, name, _parse_special_functions(expr)) - ) - - if name in observables: + if name not in observables: + ode_model.add_component( + Expression(sym, name, _parse_special_functions(expr)) + ) + else: noise_dist = ( noise_distributions.get(name, "normal") if noise_distributions @@ -648,7 +668,9 @@ def _add_expression( y = sp.Symbol(f"{name}") trafo = noise_distribution_to_observable_transformation(noise_dist) - obs = Observable(y, name, sym, transformation=trafo) + obs = Observable( + y, name, _parse_special_functions(expr), transformation=trafo + ) ode_model.add_component(obs) sigma_name, sigma_value = _get_sigma_name_and_value(