Skip to content

Commit

Permalink
fix pysb
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Feb 3, 2025
1 parent 66e541b commit 1fb536d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 24 deletions.
48 changes: 33 additions & 15 deletions python/sdist/amici/petab/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


def _add_observation_model(
pysb_model: pysb.Model, petab_problem: petab.Problem
pysb_model: pysb.Model, petab_problem: petab.Problem, jax: bool = False
):
"""Extend PySB model by observation model as defined in the PEtab
observables table"""
Expand All @@ -39,22 +39,40 @@ def _add_observation_model(
for comp in pysb_model.components
if isinstance(comp, sp.Symbol)
}
for formula in [
*petab_problem.observable_df[OBSERVABLE_FORMULA],
*petab_problem.observable_df[NOISE_FORMULA],
]:
sym = sp.sympify(formula, locals=local_syms)
for s in sym.free_symbols:
if not isinstance(s, pysb.Component):
p = pysb.Parameter(str(s), 1.0)
pysb_model.add_component(p)
local_syms[sp.Symbol.__str__(p)] = p
obs_df = petab_problem.observable_df.copy()
for col, placeholder_pattern in (

Check warning on line 43 in python/sdist/amici/petab/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/pysb_import.py#L42-L43

Added lines #L42 - L43 were not covered by tests
(OBSERVABLE_FORMULA, r"^(observableParameter\d+)_\w+$"),
(NOISE_FORMULA, r"^(noiseParameter\d+)_\w+$"),
):
for ir, formula in petab_problem.observable_df[col].items():
sym = sp.sympify(formula, locals=local_syms)
for s in sym.free_symbols:
if not isinstance(s, pysb.Component):
if jax:
name = re.sub(placeholder_pattern, r"\1", str(s))

Check warning on line 52 in python/sdist/amici/petab/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/pysb_import.py#L47-L52

Added lines #L47 - L52 were not covered by tests
else:
name = str(s)
p = pysb.Parameter(name, 1.0)
pysb_model.add_component(p)

Check warning on line 56 in python/sdist/amici/petab/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/pysb_import.py#L54-L56

Added lines #L54 - L56 were not covered by tests

# placeholders for multiple observables are mapped to the same symbol, so only add to local_syms
# when necessary
if name not in local_syms:
local_syms[name] = p

Check warning on line 61 in python/sdist/amici/petab/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/pysb_import.py#L60-L61

Added lines #L60 - L61 were not covered by tests

# replace placeholder with parameter
if jax and name != str(s):
sym = sym.subs(s, local_syms[name])

Check warning on line 65 in python/sdist/amici/petab/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/pysb_import.py#L64-L65

Added lines #L64 - L65 were not covered by tests

# update forum
if jax:
obs_df.at[ir, col] = sym

Check warning on line 69 in python/sdist/amici/petab/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/pysb_import.py#L68-L69

Added lines #L68 - L69 were not covered by tests

# add observables and sigmas to pysb model
for observable_id, observable_formula, noise_formula in zip(
petab_problem.observable_df.index,
petab_problem.observable_df[OBSERVABLE_FORMULA],
petab_problem.observable_df[NOISE_FORMULA],
obs_df.index,
obs_df[OBSERVABLE_FORMULA],
obs_df[NOISE_FORMULA],
strict=True,
):
obs_symbol = sp.sympify(observable_formula, locals=local_syms)
Expand Down Expand Up @@ -210,7 +228,7 @@ def import_model_pysb(
name=petab_problem.model.model_id,
)

_add_observation_model(pysb_model, petab_problem)
_add_observation_model(pysb_model, petab_problem, jax)

Check warning on line 231 in python/sdist/amici/petab/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/petab/pysb_import.py#L231

Added line #L231 was not covered by tests
# generate species for the _original_ model
pysb.bng.generate_equations(petab_problem.model.model)
fixed_parameters = _add_initialization_variables(pysb_model, petab_problem)
Expand Down
40 changes: 31 additions & 9 deletions python/sdist/amici/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
import logging
import os
import re

Check warning on line 11 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L11

Added line #L11 was not covered by tests
import sys
from pathlib import Path
from typing import (
Expand All @@ -33,6 +34,7 @@
SigmaY,
)
from .de_model import DEModel
from .de_model_components import NoiseParameter, ObservableParameter

Check warning on line 37 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L37

Added line #L37 was not covered by tests
from .import_utils import (
_get_str_symbol_identifiers,
_parse_special_functions,
Expand Down Expand Up @@ -137,6 +139,7 @@ def pysb2jax(
simplify=simplify,
cache_simplify=cache_simplify,
verbose=verbose,
jax=True,
)

from amici.jax.ode_export import ODEExporter
Expand Down Expand Up @@ -300,6 +303,7 @@ def ode_model_from_pysb_importer(
# See https://github.com/AMICI-dev/AMICI/pull/1672
cache_simplify: bool = False,
verbose: int | bool = False,
jax: bool = False,
) -> DEModel:
"""
Creates an :class:`amici.DEModel` instance from a :class:`pysb.Model`
Expand Down Expand Up @@ -335,6 +339,9 @@ def ode_model_from_pysb_importer(
:param verbose: verbosity level for logging, True/False default to
:attr:`logging.DEBUG`/:attr:`logging.ERROR`
:param jax:
if set to ``True``, the generated model will be compatible with JAX export
:return:
New DEModel instance according to pysbModel
"""
Expand All @@ -357,7 +364,7 @@ def ode_model_from_pysb_importer(
pysb.bng.generate_equations(model, verbose=verbose)

_process_pysb_species(model, ode)
_process_pysb_parameters(model, ode, constant_parameters)
_process_pysb_parameters(model, ode, constant_parameters, jax)

Check warning on line 367 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L367

Added line #L367 was not covered by tests
if compute_conservation_laws:
_process_pysb_conservation_laws(model, ode)
_process_pysb_observables(
Expand Down Expand Up @@ -510,7 +517,10 @@ def _process_pysb_species(pysb_model: pysb.Model, ode_model: DEModel) -> None:

@log_execution_time("processing PySB parameters", logger)
def _process_pysb_parameters(
pysb_model: pysb.Model, ode_model: DEModel, constant_parameters: list[str]
pysb_model: pysb.Model,
ode_model: DEModel,
constant_parameters: list[str],
jax: bool = False,
) -> None:
"""
Converts pysb parameters into Parameters or Constants and adds them to
Expand All @@ -522,16 +532,26 @@ def _process_pysb_parameters(
:param constant_parameters:
list of Parameters that should be constants
:param jax:
if set to ``True``, the generated model will be compatible JAX export
:param ode_model:
DEModel instance
"""
for par in pysb_model.parameters:
args = [par, f"{par.name}"]

Check warning on line 542 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L542

Added line #L542 was not covered by tests
if par.name in constant_parameters:
comp = Constant
args.append(par.value)
elif jax and re.match(r"noiseParameter\d+", par.name):
comp = NoiseParameter
elif jax and re.match(r"observableParameter\d+", par.name):
comp = ObservableParameter

Check warning on line 549 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L545-L549

Added lines #L545 - L549 were not covered by tests
else:
comp = Parameter
args.append(par.value)

Check warning on line 552 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L552

Added line #L552 was not covered by tests

ode_model.add_component(comp(par, f"{par.name}", par.value))
ode_model.add_component(comp(*args))

Check warning on line 554 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L554

Added line #L554 was not covered by tests


@log_execution_time("processing PySB expressions", logger)
Expand Down Expand Up @@ -635,11 +655,11 @@ def _add_expression(
:param ode_model:
see :py:func:`_process_pysb_expressions`
"""
ode_model.add_component(
Expression(sym, name, _parse_special_functions(expr))
)

if name in observables:
if name not in observables:
ode_model.add_component(

Check warning on line 659 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L658-L659

Added lines #L658 - L659 were not covered by tests
Expression(sym, name, _parse_special_functions(expr))
)
else:
noise_dist = (
noise_distributions.get(name, "normal")
if noise_distributions
Expand All @@ -648,7 +668,9 @@ def _add_expression(

y = sp.Symbol(f"{name}")
trafo = noise_distribution_to_observable_transformation(noise_dist)
obs = Observable(y, name, sym, transformation=trafo)
obs = Observable(

Check warning on line 671 in python/sdist/amici/pysb_import.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/pysb_import.py#L671

Added line #L671 was not covered by tests
y, name, _parse_special_functions(expr), transformation=trafo
)
ode_model.add_component(obs)

sigma_name, sigma_value = _get_sigma_name_and_value(
Expand Down

0 comments on commit 1fb536d

Please sign in to comment.