Skip to content

Commit

Permalink
Merge branch 'develop' into ft_2368_partial_equilibration
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl authored Apr 19, 2024
2 parents 858d883 + 09e0581 commit 1edc339
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 38 deletions.
28 changes: 21 additions & 7 deletions documentation/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,33 @@
import subprocess
import sys
from enum import EnumType

# need to import before setting typing.TYPE_CHECKING=True, fails otherwise
import amici
import exhale.deploy
import exhale_multiproject_monkeypatch
from unittest import mock
import pandas as pd
import sphinx
import sympy as sp
from exhale import configs as exhale_configs
from sphinx.transforms.post_transforms import ReferencesResolver

exhale_multiproject_monkeypatch, pd, sp # to avoid removal of unused import
try:
import exhale_multiproject_monkeypatch # noqa: F401
except ModuleNotFoundError:
# for unclear reasons, the import of exhale_multiproject_monkeypatch
# fails on some systems, because the the location of the editable install
# is not automatically added to sys.path ¯\_(ツ)_/¯
from importlib.metadata import Distribution
import json
from urllib.parse import unquote_plus, urlparse

dist = Distribution.from_name("sphinx-contrib-exhale-multiproject")
url = json.loads(dist.read_text("direct_url.json"))["url"]
package_dir = unquote_plus(urlparse(url).path)
sys.path.append(package_dir)
import exhale_multiproject_monkeypatch # noqa: F401

# need to import before setting typing.TYPE_CHECKING=True, fails otherwise
import amici
import pandas as pd # noqa: F401
import sympy as sp # noqa: F401


# BEGIN Monkeypatch exhale
from exhale.deploy import _generate_doxygen as exhale_generate_doxygen
Expand Down
4 changes: 1 addition & 3 deletions documentation/rtd_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# NOTE: relative paths are expected to be relative to the repository root

# sphinx<7.3.0: https://github.com/AMICI-dev/AMICI/issues/2403
sphinx<7.3.0
sphinx
mock>=5.0.2
setuptools>=67.7.2
pysb>=1.11.0
Expand Down
75 changes: 55 additions & 20 deletions python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Functions related to SWIG or SWIG-generated code"""
from __future__ import annotations
import ast
import contextlib
import re


class TypeHintFixer(ast.NodeTransformer):
"""Replaces SWIG-generated C++ typehints by corresponding Python types"""
"""Replaces SWIG-generated C++ typehints by corresponding Python types."""

mapping = {
"void": None,
Expand Down Expand Up @@ -53,9 +54,13 @@ class TypeHintFixer(ast.NodeTransformer):
"std::allocator< amici::ParameterScaling > > const &": ast.Constant(
"ParameterScalingVector"
),
"H5::H5File": None,
}

def visit_FunctionDef(self, node):
# convert type/rtype from docstring to annotation, if possible.
# those may be c++ types, not valid in python, that need to be
# converted to python types below.
self._annotation_from_docstring(node)

# Has a return type annotation?
Expand All @@ -67,14 +72,17 @@ def visit_FunctionDef(self, node):
for arg in node.args.args:
if not arg.annotation:
continue
if isinstance(arg.annotation, ast.Name):
if not isinstance(arg.annotation, ast.Constant):
# there is already proper annotation
continue

arg.annotation = self._new_annot(arg.annotation.value)
return node

def _new_annot(self, old_annot: str):
def _new_annot(self, old_annot: str | ast.Name):
if isinstance(old_annot, ast.Name):
old_annot = old_annot.id

with contextlib.suppress(KeyError):
return self.mapping[old_annot]

Expand Down Expand Up @@ -117,6 +125,8 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
Swig sometimes generates ``:type solver: :py:class:`Solver`` instead of
``:type solver: Solver``. Those need special treatment.
Overloaded functions are skipped.
"""
docstring = ast.get_docstring(node, clean=False)
if not docstring or "*Overload 1:*" in docstring:
Expand All @@ -127,22 +137,18 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
lines_to_remove = set()

for line_no, line in enumerate(docstring):
if (
match := re.match(
r"\s*:rtype:\s*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
node.returns = ast.Constant(match.group(1))
if type_str := self.extract_rtype(line):
# handle `:rtype:`
node.returns = ast.Constant(type_str)
lines_to_remove.add(line_no)
continue

if (
match := re.match(
r"\s*:type\s*(\w+):\W*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
arg_name, type_str = self.extract_type(line)
if arg_name is not None:
# handle `:type ...:`
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
if arg.arg == arg_name:
arg.annotation = ast.Constant(type_str)
lines_to_remove.add(line_no)

if lines_to_remove:
Expand All @@ -155,13 +161,42 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
)
node.body[0].value = ast.Str(new_docstring)

@staticmethod
def extract_type(line: str) -> tuple[str, str] | tuple[None, None]:
"""Extract argument name and type string from ``:type:`` docstring
line."""
match = re.match(r"\s*:type\s+(\w+):\s+(.+?)(?:, optional)?\s*$", line)
if not match:
return None, None

arg_name = match.group(1)

# get rid of any :py:class`...` in the type string if necessary
if not match.group(2).startswith(":py:"):
return arg_name, match.group(2)

match = re.match(r":py:\w+:`(.+)`", match.group(2))
assert match
return arg_name, match.group(1)

@staticmethod
def extract_rtype(line: str) -> str | None:
"""Extract type string from ``:rtype:`` docstring line."""
match = re.match(r"\s*:rtype:\s+(.+)\s*$", line)
if not match:
return None

# get rid of any :py:class`...` in the type string if necessary
if not match.group(1).startswith(":py:"):
return match.group(1)

match = re.match(r":py:\w+:`(.+)`", match.group(1))
assert match
return match.group(1)


def fix_typehints(infilename, outfilename):
"""Change SWIG-generated C++ typehints to Python typehints"""
# Only available from Python3.9
if not getattr(ast, "unparse", None):
return

# file -> AST
with open(infilename) as f:
source = f.read()
Expand Down
3 changes: 0 additions & 3 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,6 @@ int Model::checkFinite(
&& model_quantity != ModelQuantity::ts) {
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
if (!always_check_finite_ && model_quantity != ModelQuantity::w) {
// don't check twice if always_check_finite_ is true
checkFinite(derived_state_.w_, ModelQuantity::w, t);
Expand Down Expand Up @@ -1791,7 +1790,6 @@ int Model::checkFinite(
// check upstream
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
checkFinite(derived_state_.w_, ModelQuantity::w, t);

return AMICI_RECOVERABLE_ERROR;
Expand Down Expand Up @@ -1882,7 +1880,6 @@ int Model::checkFinite(SUNMatrix m, ModelQuantity model_quantity, realtype t)
// check upstream
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
checkFinite(derived_state_.w_, ModelQuantity::w, t);

return AMICI_RECOVERABLE_ERROR;
Expand Down
10 changes: 5 additions & 5 deletions src/solver_idas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
namespace amici {

/*
* The following static members are callback function to CVODES.
* The following static members are callback function to IDAS.
* Their signatures must not be changes.
*/

Expand Down Expand Up @@ -437,7 +437,7 @@ void IDASolver::reInitPostProcess(

auto status = IDASetStopTime(ida_mem, tout);
if (status != IDA_SUCCESS)
throw IDAException(status, "CVodeSetStopTime");
throw IDAException(status, "IDASetStopTime");

status = IDASolve(
ami_mem, tout, t, yout->getNVector(), ypout->getNVector(), IDA_ONE_STEP
Expand Down Expand Up @@ -853,7 +853,7 @@ void IDASolver::setNonLinearSolver() const {
solver_memory_.get(), non_linear_solver_->get()
);
if (status != IDA_SUCCESS)
throw CvodeException(status, "CVodeSetNonlinearSolver");
throw IDAException(status, "IDASetNonlinearSolver");
}

void IDASolver::setNonLinearSolverSens() const {
Expand Down Expand Up @@ -883,15 +883,15 @@ void IDASolver::setNonLinearSolverSens() const {
}

if (status != IDA_SUCCESS)
throw CvodeException(status, "CVodeSolver::setNonLinearSolverSens");
throw IDAException(status, "IDASolver::setNonLinearSolverSens");
}

void IDASolver::setNonLinearSolverB(int which) const {
int status = IDASetNonlinearSolverB(
solver_memory_.get(), which, non_linear_solver_B_->get()
);
if (status != IDA_SUCCESS)
throw CvodeException(status, "CVodeSetNonlinearSolverB");
throw IDAException(status, "IDASetNonlinearSolverB");
}

/**
Expand Down
2 changes: 2 additions & 0 deletions swig/misc.i
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
%ignore amici::regexErrorToString;
%ignore amici::writeSlice;
%ignore ContextManager;
%ignore amici::scaleParameters;
%ignore amici::unscaleParameters;

// Add necessary symbols to generated header
%{
Expand Down

0 comments on commit 1edc339

Please sign in to comment.