Skip to content

Commit

Permalink
Accept *Model instances where an equivalent dict or list is acc…
Browse files Browse the repository at this point in the history
…epted.
  • Loading branch information
daemontus committed Oct 26, 2024
1 parent 847f043 commit fc202df
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 93 deletions.
14 changes: 7 additions & 7 deletions biodivine_aeon/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1364,8 +1364,8 @@ class AsynchronousGraph:
@overload
def transfer_from(self, set: ColoredVertexSet, original_ctx: AsynchronousGraph) -> ColoredVertexSet: ...
def transfer_from(self, set, original_ctx): ...
def mk_subspace(self, subspace: Union[Mapping[VariableId, BoolType], Mapping[str, BoolType]]) -> ColoredVertexSet: ...
def mk_subspace_vertices(self, subspace: Union[Mapping[VariableId, BoolType], Mapping[str, BoolType]]) -> VertexSet: ...
def mk_subspace(self, subspace: Union[Mapping[VariableId, BoolType], Mapping[str, BoolType], VertexModel]) -> ColoredVertexSet: ...
def mk_subspace_vertices(self, subspace: Union[Mapping[VariableId, BoolType], Mapping[str, BoolType], VertexModel]) -> VertexSet: ...
def mk_update_function(self, variable: VariableIdType) -> Bdd: ...
def post(self, set: ColoredVertexSet) -> ColoredVertexSet: ...
#def post_out(self, set: ColoredVertexSet) -> ColoredVertexSet: ...
Expand Down Expand Up @@ -1748,9 +1748,9 @@ class ColoredPerturbationSet:
def intersect_perturbations(self, vertices: PerturbationSet) -> ColoredPerturbationSet: ...
def minus_colors(self, colors: ColorSet) -> ColoredPerturbationSet: ...
def minus_perturbations(self, vertices: PerturbationSet) -> ColoredPerturbationSet: ...
def select_perturbation(self, perturbation: Mapping[VariableIdType, Optional[bool]]) -> ColorSet: ...
def select_perturbations(self, perturbations: Mapping[VariableIdType, Optional[bool]]) -> ColoredPerturbationSet: ...
def perturbation_robustness(self, perturbation: Mapping[VariableIdType, Optional[bool]]) -> float: ...
def select_perturbation(self, perturbation: Union[Mapping[VariableIdType, Optional[bool]], PerturbationModel]) -> ColorSet: ...
def select_perturbations(self, perturbations: Union[Mapping[VariableIdType, Optional[bool]], PerturbationModel]) -> ColoredPerturbationSet: ...
def perturbation_robustness(self, perturbation: Union[Mapping[VariableIdType, Optional[bool]], PerturbationModel]) -> float: ...
def select_by_size(self, size: int, up_to: bool) -> ColoredPerturbationSet: ...
def select_by_robustness(self, threshold: float, result_limit: Optional[int] = None) -> list[tuple[PerturbationModel, float, ColorSet]]: ...
def pick_singleton(self) -> ColoredPerturbationSet: ...
Expand Down Expand Up @@ -1785,8 +1785,8 @@ class AsynchronousPerturbationGraph(AsynchronousGraph):
def mk_empty_colored_perturbations(self) -> ColoredPerturbationSet: ...
def mk_perturbable_unit_colors(self) -> ColorSet: ...
def mk_perturbable_unit_colored_vertices(self) -> ColoredVertexSet: ...
def mk_perturbation(self, perturbation: Mapping[VariableIdType, Optional[bool]]) -> PerturbationSet: ...
def mk_perturbations(self, perturbations: Mapping[VariableIdType, Optional[bool]]) -> PerturbationSet: ...
def mk_perturbation(self, perturbation: Union[Mapping[VariableIdType, Optional[bool]], PerturbationModel]) -> PerturbationSet: ...
def mk_perturbations(self, perturbations: Union[Mapping[VariableIdType, Optional[bool]], PerturbationModel]) -> PerturbationSet: ...
def mk_perturbations_with_size(self, size: int, up_to: bool) -> PerturbationSet: ...
def colored_robustness(self, set: ColorSet) -> float: ...

Expand Down
10 changes: 9 additions & 1 deletion src/bindings/lib_param_bn/symbolic/asynchronous_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::bindings::lib_param_bn::symbolic::set_vertex::VertexSet;
use crate::bindings::lib_param_bn::symbolic::symbolic_context::SymbolicContext;

use crate::bindings::lib_hctl_model_checker::hctl_formula::HctlFormula;
use crate::bindings::lib_param_bn::symbolic::model_vertex::VertexModel;
use crate::bindings::lib_param_bn::variable_id::VariableId;
use crate::bindings::lib_param_bn::NetworkVariableContext;
use crate::pyo3_utils::BoolLikeValue;
Expand Down Expand Up @@ -751,8 +752,15 @@ impl AsynchronousGraph {
result.push((k, v.bool()));
}
return Ok(result);
} else if let Ok(model) = subspace.downcast::<VertexModel>() {
return Ok(model
.get()
.items()
.into_iter()
.map(|(a, b)| (a.into(), b))
.collect());
}
throw_type_error("Expected a dictionary of `VariableIdType` keys and `BoolType` values.")
throw_type_error("Expected a dictionary of `VariableIdType` keys and `BoolType` values or a `VertexModel`.")
}

pub fn wrap_native(py: Python, stg: SymbolicAsyncGraph) -> PyResult<AsynchronousGraph> {
Expand Down
6 changes: 0 additions & 6 deletions src/bindings/lib_param_bn/symbolic/model_color.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,6 @@ impl ColorModel {
assert_args_is_none(args)?;
let mut bn = network.borrow(py).as_native().clone();

// This is the expected number of parameters after the ones available in this model
// are instantiated.
let expected = (bn.num_parameters() + bn.num_implicit_parameters())
- (self.retained_implicit.len() + self.retained_explicit.len());

for var in bn.variables() {
let function = if let Some(function) = bn.get_update_function(var) {
self.instantiate_fn_update(function)?
Expand All @@ -278,7 +273,6 @@ impl ColorModel {
}

let bn = bn.prune_unused_parameters();
assert_eq!(bn.num_parameters() + bn.num_implicit_parameters(), expected);

let bn = if infer_regulations.unwrap_or_default() {
bn.infer_valid_graph().map_err(runtime_error)?
Expand Down
82 changes: 55 additions & 27 deletions src/bindings/pbn_control/asynchronous_perturbation_graph.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use biodivine_lib_param_bn::biodivine_std::bitvector::ArrayBitVector;
use biodivine_lib_param_bn::symbolic_async_graph::{GraphColoredVertices, GraphColors};
Expand All @@ -19,9 +19,9 @@ use crate::bindings::lib_param_bn::variable_id::VariableId;
use crate::bindings::lib_param_bn::NetworkVariableContext;
use crate::bindings::pbn_control::control::sanitize_control_map;
use crate::bindings::pbn_control::set_colored_perturbation::ColoredPerturbationSet;
use crate::bindings::pbn_control::PerturbationSet;
use crate::bindings::pbn_control::{PerturbationModel, PerturbationSet};
use crate::pyo3_utils::BoolLikeValue;
use crate::{throw_runtime_error, AsNative};
use crate::{throw_runtime_error, throw_type_error, AsNative};

/// An extension of `AsynchronousGraph` that admits various variable perturbations through
/// additional colors/parameters. Such graph can then be analyzed to extract control strategies
Expand Down Expand Up @@ -384,7 +384,7 @@ impl AsynchronousPerturbationGraph {
pub fn mk_perturbation(
_self: Py<Self>,
py: Python,
perturbation: &Bound<'_, PyDict>,
perturbation: &Bound<'_, PyAny>,
) -> PyResult<PerturbationSet> {
let self_borrow = _self.borrow(py);
let parent = self_borrow.as_ref();
Expand All @@ -394,24 +394,20 @@ impl AsynchronousPerturbationGraph {
.get()
.as_native()
.get_perturbation_bdd_mapping(perturbable);

// Init the partial valuation such that everything is unperturbed initially.
for bdd_var in map.values() {
partial_valuation.set_value(*bdd_var, false);
}

let perturbation = Self::resolve_perturbation(&self_borrow, perturbation)?;

// Read data from the dictionary.
for (k, v) in perturbation {
let k_var = parent.resolve_network_variable(&k)?;
let s_var = parent
.as_native()
.symbolic_context()
.get_state_variable(k_var);
let Some(p_var) = map.get(&k_var).cloned() else {
return throw_runtime_error(format!("Variable {k_var} cannot be perturbed."));
};

let val = v.extract::<Option<bool>>()?;
match val {
let s_var = parent.as_native().symbolic_context().get_state_variable(k);
let p_var = *map.get(&k).unwrap();

match v {
None => partial_valuation.set_value(p_var, false),
Some(val) => {
partial_valuation.set_value(p_var, true);
Expand Down Expand Up @@ -440,7 +436,7 @@ impl AsynchronousPerturbationGraph {
pub fn mk_perturbations(
_self: Py<Self>,
py: Python,
perturbations: &Bound<'_, PyDict>,
perturbations: &Bound<'_, PyAny>,
) -> PyResult<PerturbationSet> {
let self_borrow = _self.borrow(py);
let parent = self_borrow.as_ref();
Expand All @@ -450,18 +446,12 @@ impl AsynchronousPerturbationGraph {
.get()
.as_native()
.get_perturbation_bdd_mapping(perturbable);
let perturbations = Self::resolve_perturbation(&self_borrow, perturbations)?;
for (k, v) in perturbations {
let k_var = parent.resolve_network_variable(&k)?;
let s_var = parent
.as_native()
.symbolic_context()
.get_state_variable(k_var);
let Some(p_var) = map.get(&k_var).cloned() else {
return throw_runtime_error(format!("Variable {k_var} cannot be perturbed."));
};

let val = v.extract::<Option<bool>>()?;
match val {
let s_var = parent.as_native().symbolic_context().get_state_variable(k);
let p_var = *map.get(&k).unwrap();

match v {
None => partial_valuation.set_value(p_var, false),
Some(val) => {
partial_valuation.set_value(p_var, true);
Expand Down Expand Up @@ -621,4 +611,42 @@ impl AsynchronousPerturbationGraph {
let set = GraphColoredVertices::new(bdd, ctx.get().as_native());
ColoredVertexSet::mk_native(self_ref.as_ref().symbolic_context(), set)
}

/// Returns a list of perturbed variables together with their values, or error if the
/// variables are invalid (e.g. not perturbable). If a variable is not present, it is not
/// returned. It is up to the caller to interpret this correctly.
pub fn resolve_perturbation(
_self: &PyRef<'_, AsynchronousPerturbationGraph>,
value: &Bound<'_, PyAny>,
) -> PyResult<HashMap<biodivine_lib_param_bn::VariableId, Option<bool>>> {
let parent_ref = _self.as_ref();
let perturbable: HashSet<biodivine_lib_param_bn::VariableId> =
HashSet::from_iter(_self.as_native().perturbable_variables().clone());
let mut result = HashMap::new();
if let Ok(dict) = value.downcast::<PyDict>() {
for (k, v) in dict {
let k_var = parent_ref.resolve_network_variable(&k)?;

if !perturbable.contains(&k_var) {
return throw_runtime_error(format!("Variable {k_var} cannot be perturbed."));
};

let val = v.extract::<Option<bool>>()?;
result.insert(k_var, val);
}
} else if let Ok(model) = value.downcast::<PerturbationModel>() {
for (k, v) in model.get().items() {
let k_var: biodivine_lib_param_bn::VariableId = k.into();

if !perturbable.contains(&k_var) {
return throw_runtime_error(format!("Variable {k_var} cannot be perturbed."));
};

result.insert(k_var, v);
}
} else {
return throw_type_error("Expected a dictionary of `VariableIdType` keys and `BoolType | None` values, or a `PerturbationModel`.");
}
Ok(result)
}
}
48 changes: 23 additions & 25 deletions src/bindings/pbn_control/set_colored_perturbation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use either::Either;
use num_bigint::BigInt;
use pyo3::basic::CompareOp;
use pyo3::prelude::PyListMethods;
use pyo3::types::{PyAnyMethods, PyDict, PyList};
use pyo3::types::PyList;
use pyo3::{pyclass, pymethods, Bound, IntoPy, Py, PyAny, PyResult, Python};

use crate::bindings::lib_bdd::bdd::Bdd;
Expand Down Expand Up @@ -267,31 +267,30 @@ impl ColoredPerturbationSet {
///
/// To specify that a variable should be unperturbed, use `"var": None`. Any variable that
/// should remain unrestricted should be completely omitted from the `perturbations`
/// dictionary. This is similar to `AsynchronousPerturbationGraph.mk_perturbations`.
fn select_perturbations(
/// dictionary. This is similar to `AsynchronousPerturbationGraph.mk_perturbations`. If a
/// `PerturbationModel` is given, only values that are omitted through projection will be
/// considered as unrestricted.
pub fn select_perturbations(
&self,
py: Python,
perturbations: &Bound<'_, PyDict>,
perturbations: &Bound<'_, PyAny>,
) -> PyResult<ColoredPerturbationSet> {
let borrowed = self.ctx.borrow(py);
let parent = borrowed.as_ref();
let native_graph = self.ctx.get().as_native();
let mapping =
native_graph.get_perturbation_bdd_mapping(native_graph.perturbable_variables());

let perturbations =
AsynchronousPerturbationGraph::resolve_perturbation(&borrowed, perturbations)?;
let mut selection = biodivine_lib_bdd::BddPartialValuation::empty();

// Go through the given perturbation and mark everything that should be perturbed.
for (k, v) in perturbations {
let k_var = parent.resolve_network_variable(&k)?;
let s_var = parent
.as_native()
.symbolic_context()
.get_state_variable(k_var);
let Some(p_var) = mapping.get(&k_var).cloned() else {
return throw_runtime_error(format!("Variable {k_var} cannot be perturbed."));
};
let s_var = parent.as_native().symbolic_context().get_state_variable(k);
let p_var = *mapping.get(&k).unwrap();

match v.extract::<Option<bool>>()? {
match v {
Some(val) => {
selection.set_value(p_var, true);
selection.set_value(s_var, val);
Expand All @@ -312,18 +311,22 @@ impl ColoredPerturbationSet {
///
/// *Note that here, we assume that the dictionary represents a single perturbation. Therefore,
/// any missing perturbable variables are treated as unperturbed.* This is the same behavior
/// as in `AsynchronousPerturbationGraph.mk_perturbation`.
/// as in `AsynchronousPerturbationGraph.mk_perturbation`. Similarly, if a `PerturbationModel`
/// is provided with some values eliminated through projection, these are assumed to be
/// unperturbed.
///
fn select_perturbation(
&self,
py: Python,
perturbation: &Bound<'_, PyDict>,
perturbation: &Bound<'_, PyAny>,
) -> PyResult<ColorSet> {
let borrowed = self.ctx.borrow(py);
let parent = borrowed.as_ref();
let native_graph = self.ctx.get().as_native();
let mapping =
native_graph.get_perturbation_bdd_mapping(native_graph.perturbable_variables());
let perturbation =
AsynchronousPerturbationGraph::resolve_perturbation(&borrowed, perturbation)?;
let mut restriction = biodivine_lib_bdd::BddPartialValuation::empty();

// Initially, set all variables to unperturbed.
Expand All @@ -333,16 +336,11 @@ impl ColoredPerturbationSet {

// Then go through the given perturbation and mark everything that should be perturbed.
for (k, v) in perturbation {
let k_var = parent.resolve_network_variable(&k)?;
let s_var = parent
.as_native()
.symbolic_context()
.get_state_variable(k_var);
let Some(p_var) = mapping.get(&k_var).cloned() else {
return throw_runtime_error(format!("Variable {k_var} cannot be perturbed."));
};
let s_var = parent.as_native().symbolic_context().get_state_variable(k);
// Unwrap safe because of `resolve_perturbation`.
let p_var = *mapping.get(&k).unwrap();

if let Some(val) = v.extract::<Option<bool>>()? {
if let Some(val) = v {
restriction.set_value(p_var, true);
restriction.set_value(s_var, val);
}
Expand Down Expand Up @@ -373,7 +371,7 @@ impl ColoredPerturbationSet {
fn perturbation_robustness(
&self,
py: Python,
perturbation: &Bound<'_, PyDict>,
perturbation: &Bound<'_, PyAny>,
) -> PyResult<f64> {
let colors = self.select_perturbation(py, perturbation)?;
AsynchronousPerturbationGraph::colored_robustness(self.ctx.bind(py).clone(), &colors)
Expand Down
Loading

0 comments on commit fc202df

Please sign in to comment.