From cc620e15061572ea3bccb910bf7eccf7df9635eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 16 May 2024 15:59:32 +0200 Subject: [PATCH] add timeseries class --- python/sdist/amici/swig_wrappers.py | 52 ++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/python/sdist/amici/swig_wrappers.py b/python/sdist/amici/swig_wrappers.py index 2e90891df3..95d1235d38 100644 --- a/python/sdist/amici/swig_wrappers.py +++ b/python/sdist/amici/swig_wrappers.py @@ -73,10 +73,35 @@ def runAmiciSimulation( return numpy.ReturnDataView(rdata) +class ExpDataTimeseries: + """ + Prototype for timeseries data storage. + """ + + _exp_datas: list[amici.ExpData] + + def __len__(self): + return len(self._exp_datas) + + def __init__(self, exp_datas: list[amici.ExpData]): + self._exp_datas = exp_datas + + def __getitem__(self, key): + return next(edata for edata in self._exp_datas if edata.id == key) + + def __setitem__(self, key, value): + self._exp_datas = [ + edata if edata.id != key else value for edata in self._exp_datas + ] + + def get_ptr_vector(self): + return amici_swig.ExpDataPtrVector(self._exp_datas) + + def runAmiciSimulations( model: AmiciModel, solver: AmiciSolver, - edata_list: AmiciExpDataVector, + edata_list: AmiciExpDataVector | ExpDataTimeseries, failfast: bool = True, num_threads: int = 1, ) -> list["numpy.ReturnDataView"]: @@ -105,14 +130,23 @@ def runAmiciSimulations( stacklevel=1, ) - edata_ptr_vector = amici_swig.ExpDataPtrVector(edata_list) - rdata_ptr_list = amici_swig.runAmiciSimulations( - _get_ptr(solver), - edata_ptr_vector, - _get_ptr(model), - failfast, - num_threads, - ) + if isinstance(edata_list, ExpDataTimeseries): + edata_ptr_vector = edata_list.get_ptr_vector() + rdata_ptr_list = amici_swig.runAmiciSimulationsTimeseries( + _get_ptr(solver), + edata_ptr_vector, + _get_ptr(model), + failfast, + ) + else: + edata_ptr_vector = amici_swig.ExpDataPtrVector(edata_list) + rdata_ptr_list = amici_swig.runAmiciSimulations( + _get_ptr(solver), + edata_ptr_vector, + _get_ptr(model), + failfast, + num_threads, + ) for rdata in rdata_ptr_list: _log_simulation(rdata) if solver.getReturnDataReportingMode() == amici.RDataReporting.full: