Skip to content

Commit

Permalink
zip(... , strict=True) (#2412)
Browse files Browse the repository at this point in the history
Closes #2410.
  • Loading branch information
dweindl authored Apr 23, 2024
1 parent b2aeae0 commit a1d1451
Show file tree
Hide file tree
Showing 21 changed files with 100 additions and 51 deletions.
8 changes: 4 additions & 4 deletions python/sdist/amici/conserved_quantities_demartino.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def log(*args, **kwargs):
# print all conserved quantities
if verbose:
for i, (coefficients, engaged_species_idxs) in enumerate(
zip(species_coefficients, species_indices)
zip(species_coefficients, species_indices, strict=True)
):
if not engaged_species_idxs:
continue
Expand All @@ -148,7 +148,7 @@ def log(*args, **kwargs):
"species:"
)
for species_idx, coefficient in zip(
engaged_species_idxs, coefficients
engaged_species_idxs, coefficients, strict=True
):
name = (
species_names[species_idx]
Expand Down Expand Up @@ -957,12 +957,12 @@ def _reduce(
k2 = order[j]
column: list[float] = [0] * num_species
for species_idx, coefficient in zip(
cls_species_idxs[k1], cls_coefficients[k1]
cls_species_idxs[k1], cls_coefficients[k1], strict=True
):
column[species_idx] = coefficient
ok1 = True
for species_idx, coefficient in zip(
cls_species_idxs[k2], cls_coefficients[k2]
cls_species_idxs[k2], cls_coefficients[k2], strict=True
):
column[species_idx] -= coefficient
if column[species_idx] < -_MIN:
Expand Down
10 changes: 7 additions & 3 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def format_regular_line(symbol, math, index):
# we need toposort to handle the dependencies of extracted
# subexpressions
expr_dict = dict(
itertools.chain(zip(symbols, reduced_exprs), replacements)
itertools.chain(
zip(symbols, reduced_exprs, strict=True), replacements
)
)
sorted_symbols = toposort(
{
Expand All @@ -192,7 +194,7 @@ def format_regular_line(symbol, math, index):
}
)
symbol_to_idx = {
sym: idx for idx, sym in zip(indices, symbols)
sym: idx for idx, sym in zip(indices, symbols, strict=True)
}

def format_line(symbol: sp.Symbol):
Expand All @@ -217,7 +219,9 @@ def format_line(symbol: sp.Symbol):

return [
format_regular_line(sym, math, index)
for index, sym, math in zip(indices, symbols, equations)
for index, sym, math in zip(
indices, symbols, equations, strict=True
)
if math not in [0, 0.0]
]

Expand Down
6 changes: 4 additions & 2 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,9 @@ def _get_function_body(
for ipar in range(self.model.num_par()):
expressions = []
for index, formula in zip(
self.model._x0_fixedParameters_idx, equations[:, ipar]
self.model._x0_fixedParameters_idx,
equations[:, ipar],
strict=True,
):
if not formula.is_zero:
expressions.extend(
Expand All @@ -735,7 +737,7 @@ def _get_function_body(

elif function == "x0_fixedParameters":
for index, formula in zip(
self.model._x0_fixedParameters_idx, equations
self.model._x0_fixedParameters_idx, equations, strict=True
):
lines.append(
f" if(std::find(reinitialization_state_idxs.cbegin(), "
Expand Down
23 changes: 17 additions & 6 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,9 @@ def states(self) -> list[State]:
def _process_sbml_rate_of(self) -> None:
"""Substitute any SBML-rateOf constructs in the model equations"""
rate_of_func = sp.core.function.UndefinedFunction("rateOf")
species_sym_to_xdot = dict(zip(self.sym("x"), self.sym("xdot")))
species_sym_to_xdot = dict(
zip(self.sym("x"), self.sym("xdot"), strict=True)
)
species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))}

def get_rate(symbol: sp.Symbol):
Expand Down Expand Up @@ -427,7 +429,7 @@ def get_rate(symbol: sp.Symbol):
if made_substitutions:
# substitute in topological order
subs = toposort_symbols(
dict(zip(self.sym("xdot"), self.eq("xdot")))
dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True))
)
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)

Expand Down Expand Up @@ -469,6 +471,7 @@ def get_rate(symbol: sp.Symbol):
zip(
self.sym("w")[self.num_cons_law() :, :],
self.eq("w")[self.num_cons_law() :, :],
strict=True,
)
)
)
Expand All @@ -477,6 +480,7 @@ def get_rate(symbol: sp.Symbol):
zip(
self.sym("w")[: self.num_cons_law(), :],
self.eq("w")[: self.num_cons_law(), :],
strict=True,
)
)
| w_sorted
Expand Down Expand Up @@ -1027,7 +1031,9 @@ def static_indices(self, name: str) -> list[int]:
# of non-zeros entries of the sparse matrix
self._static_indices[name] = [
i
for i, (expr, row_idx) in enumerate(zip(sparseeq, rowvals))
for i, (expr, row_idx) in enumerate(
zip(sparseeq, rowvals, strict=True)
)
# derivative of a static expression is static
if row_idx in static_indices_w
# constant expressions
Expand Down Expand Up @@ -1396,6 +1402,8 @@ def _compute_equation(self, name: str) -> None:
if not s.has_conservation_law()
),
self.sym("dx"),
# dx contains extra elements for algebraic states
strict=False,
)
]
+ [eq.get_val() for eq in self._algebraic_equations]
Expand Down Expand Up @@ -1556,7 +1564,9 @@ def _compute_equation(self, name: str) -> None:
# backsubstitution of optimized right-hand side terms into RHS
# calling subs() is costly. Due to looping over events though, the
# following lines are only evaluated if a model has events
w_sorted = toposort_symbols(dict(zip(self.sym("w"), self.eq("w"))))
w_sorted = toposort_symbols(
dict(zip(self.sym("w"), self.eq("w"), strict=True))
)
tmp_xdot = smart_subs_dict(self.eq("xdot"), w_sorted)
self._eqs[name] = self.eq("drootdt")
if self.num_states_solver():
Expand Down Expand Up @@ -1586,7 +1596,7 @@ def _compute_equation(self, name: str) -> None:
for event_obs in self._event_observables
]
for (iz, ie), event_obs in zip(
enumerate(z2event), self._event_observables
enumerate(z2event), self._event_observables, strict=True
):
event_observables[ie - 1][iz] = event_obs.get_val()

Expand Down Expand Up @@ -1727,7 +1737,7 @@ def _compute_equation(self, name: str) -> None:
syms_x = self.sym("x")
syms_yz = self.sym(name.removeprefix("sigma"))
xs_in_sigma = {}
for sym_yz, eq_yz in zip(syms_yz, self._eqs[name]):
for sym_yz, eq_yz in zip(syms_yz, self._eqs[name], strict=True):
yz_free_syms = eq_yz.free_symbols
if tmp := {x for x in syms_x if x in yz_free_syms}:
xs_in_sigma[sym_yz] = tmp
Expand Down Expand Up @@ -2229,6 +2239,7 @@ def _collect_heaviside_roots(
zip(
[expr.get_id() for expr in self._expressions],
[expr.get_val() for expr in self._expressions],
strict=True,
)
)
)
Expand Down
6 changes: 3 additions & 3 deletions python/sdist/amici/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def getSimulationObservablesAsDataFrame(

# aggregate records
dicts = []
for edata, rdata in zip(edata_list, rdata_list):
for edata, rdata in zip(edata_list, rdata_list, strict=True):
for i_time, timepoint in enumerate(rdata["t"]):
datadict = {
"time": timepoint,
Expand Down Expand Up @@ -212,7 +212,7 @@ def getSimulationStatesAsDataFrame(

# aggregate records
dicts = []
for edata, rdata in zip(edata_list, rdata_list):
for edata, rdata in zip(edata_list, rdata_list, strict=True):
for i_time, timepoint in enumerate(rdata["t"]):
datadict = {
"time": timepoint,
Expand Down Expand Up @@ -268,7 +268,7 @@ def get_expressions_as_dataframe(

# aggregate records
dicts = []
for edata, rdata in zip(edata_list, rdata_list):
for edata, rdata in zip(edata_list, rdata_list, strict=True):
for i_time, timepoint in enumerate(rdata["t"]):
datadict = {
"time": timepoint,
Expand Down
4 changes: 3 additions & 1 deletion python/sdist/amici/petab/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def fill_in_parameters(
RuntimeWarning,
)

for edata, mapping_for_condition in zip(edatas, parameter_mapping):
for edata, mapping_for_condition in zip(
edatas, parameter_mapping, strict=True
):
fill_in_parameters_for_condition(
edata,
problem_parameters,
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/petab/parameter_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def create_parameter_mapping(

parameter_mapping = ParameterMapping()
for (_, condition), prelim_mapping_for_condition in zip(
simulation_conditions.iterrows(), prelim_parameter_mapping
simulation_conditions.iterrows(), prelim_parameter_mapping, strict=True
):
mapping_for_condition = create_parameter_mapping_for_condition(
prelim_mapping_for_condition, condition, petab_problem, amici_model
Expand Down
1 change: 1 addition & 0 deletions python/sdist/amici/petab/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _add_observation_model(
petab_problem.observable_df.index,
petab_problem.observable_df[OBSERVABLE_FORMULA],
petab_problem.observable_df[NOISE_FORMULA],
strict=True,
):
obs_symbol = sp.sympify(observable_formula, locals=local_syms)
if observable_id in pysb_model.expressions.keys():
Expand Down
7 changes: 5 additions & 2 deletions python/sdist/amici/petab/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def simulate_petab(
zip(
petab_problem.x_ids,
petab_problem.x_nominal_scaled,
strict=True,
)
)
# depending on `fill_fixed_parameters` for parameter mapping, the
Expand Down Expand Up @@ -311,7 +312,7 @@ def aggregate_sllh(
)

for condition_parameter_mapping, edata, rdata in zip(
parameter_mapping, edatas, rdatas
parameter_mapping, edatas, rdatas, strict=True
):
for sllh_parameter_index, condition_parameter_sllh in enumerate(
rdata.sllh
Expand Down Expand Up @@ -433,7 +434,9 @@ def rdatas_to_measurement_df(
observable_ids = model.getObservableIds()
rows = []
# iterate over conditions
for (_, condition), rdata in zip(simulation_conditions.iterrows(), rdatas):
for (_, condition), rdata in zip(
simulation_conditions.iterrows(), rdatas, strict=True
):
# current simulation matrix
y = rdata.y
# time array used in rdata
Expand Down
4 changes: 2 additions & 2 deletions python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def plot_state_trajectories(
else:
labels = np.asarray(rdata.ptr.state_ids)[list(state_indices)]

for ix, label in zip(state_indices, labels):
for ix, label in zip(state_indices, labels, strict=True):
ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label)

ax.set_xlabel("$t$")
Expand Down Expand Up @@ -131,7 +131,7 @@ def plot_observable_trajectories(
else:
labels = np.asarray(rdata.ptr.observable_ids)[list(observable_indices)]

for iy, label in zip(observable_indices, labels):
for iy, label in zip(observable_indices, labels, strict=True):
(l,) = ax.plot(
rdata["t"], rdata["y"][:, iy], marker=marker, label=label
)
Expand Down
6 changes: 5 additions & 1 deletion python/sdist/amici/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,11 @@ def _add_expression(
cost_fun_expr = sp.sympify(
cost_fun_str,
locals=dict(
zip(_get_str_symbol_identifiers(name), (y, my, sigma))
zip(
_get_str_symbol_identifiers(name),
(y, my, sigma),
strict=True,
)
),
)
ode_model.add_component(
Expand Down
28 changes: 22 additions & 6 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,18 @@ def _build_ode_model(
dxdt = smart_multiply(
self.stoichiometric_matrix, MutableDenseMatrix(fluxes)
)
# dxdt has algebraic states at the end
assert dxdt.shape[0] - len(self.symbols[SymbolId.SPECIES]) == len(
self.symbols.get(SymbolId.ALGEBRAIC_STATE, [])
), (
self.symbols[SymbolId.SPECIES],
dxdt,
self.symbols[SymbolId.SPECIES],
)

# correct time derivatives for compartment changes
for ix, ((species_id, species), formula) in enumerate(
zip(self.symbols[SymbolId.SPECIES].items(), dxdt)
zip(self.symbols[SymbolId.SPECIES].items(), dxdt, strict=False)
):
# rate rules and amount species don't need to be updated
if "dt" in species:
Expand Down Expand Up @@ -604,7 +613,7 @@ def _build_ode_model(

# add fluxes as expressions, this needs to happen after base
# expressions from symbols have been parsed
for flux_id, flux in zip(fluxes, self.flux_vector):
for flux_id, flux in zip(fluxes, self.flux_vector, strict=True):
# replace splines inside fluxes
flux = flux.subs(spline_subs)
ode_model.add_component(
Expand Down Expand Up @@ -981,7 +990,9 @@ def _process_species_initial(self):
self.symbols[SymbolId.SPECIES], "init"
)
for species, rateof_dummies in zip(
self.symbols[SymbolId.SPECIES].values(), all_rateof_dummies
self.symbols[SymbolId.SPECIES].values(),
all_rateof_dummies,
strict=True,
):
species["init"] = _dummy_to_rateof(
smart_subs_dict(species["init"], sorted_species, "init"),
Expand Down Expand Up @@ -1945,6 +1956,7 @@ def _process_log_likelihood(
for (obs_id, obs), (sigma_id, sigma) in zip(
self.symbols[obs_symbol].items(),
self.symbols[sigma_symbol].items(),
strict=True,
):
symbol = symbol_with_assumptions(f"J{obs_id}")
dist = noise_distributions.get(str(obs_id), "normal")
Expand All @@ -1955,6 +1967,7 @@ def _process_log_likelihood(
zip(
_get_str_symbol_identifiers(obs_id),
(obs_id, obs["measurement_symbol"], sigma_id),
strict=True,
)
),
)
Expand Down Expand Up @@ -2173,9 +2186,9 @@ def _get_conservation_laws_demartino(
len(cls_coefficients), len(ode_model._differential_states)
)
for i_cl, (cl, coefficients) in enumerate(
zip(cls_state_idxs, cls_coefficients)
zip(cls_state_idxs, cls_coefficients, strict=True)
):
for i, c in zip(cl, coefficients):
for i, c in zip(cl, coefficients, strict=True):
A[i_cl, i] = sp.Rational(c)
rref, pivots = A.rref()

Expand Down Expand Up @@ -2319,7 +2332,10 @@ def _add_conservation_for_non_constant_species(
"coefficients": {
state_id: coeff * compartment
for state_id, coeff, compartment in zip(
state_ids, coefficients, compartment_sizes
state_ids,
coefficients,
compartment_sizes,
strict=True,
)
},
}
Expand Down
8 changes: 5 additions & 3 deletions python/sdist/amici/splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,9 @@ def check_if_valid(self, importer: sbml_import.SbmlImporter) -> None:
importer.symbols[SymbolId.FIXED_PARAMETER][fp]["value"]
for fp in fixed_parameters
]
subs = dict(zip(fixed_parameters, fixed_parameters_values))
subs = dict(
zip(fixed_parameters, fixed_parameters_values, strict=True)
)
nodes_values = [sp.simplify(x.subs(subs)) for x in self.nodes]
for x in nodes_values:
assert x.is_Number
Expand Down Expand Up @@ -1093,7 +1095,7 @@ def add_to_sbml_model(
# It makes no sense to give a single nominal value:
# grid values must all be different
raise TypeError("x_nominal must be a Sequence!")
for _x, _val in zip(self.nodes, x_nominal):
for _x, _val in zip(self.nodes, x_nominal, strict=True):
if _x.is_Symbol and not model.getParameter(_x.name):
add_parameter(
model, _x.name, value=_val, units=x_units
Expand All @@ -1116,7 +1118,7 @@ def add_to_sbml_model(
else:
y_constant = len(self.values_at_nodes) * [y_constant]
for _y, _val, _const in zip(
self.values_at_nodes, y_nominal, y_constant
self.values_at_nodes, y_nominal, y_constant, strict=True
):
if _y.is_Symbol and not model.getParameter(_y.name):
add_parameter(
Expand Down
Loading

0 comments on commit a1d1451

Please sign in to comment.