Skip to content

Commit

Permalink
Merge pull request #186 from NREL/rjf/query-time-state-features
Browse files Browse the repository at this point in the history
allow query to overwrite state features
  • Loading branch information
robfitzgerald authored Apr 11, 2024
2 parents 5ef2f58 + b7f9cb7 commit 29d944b
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 23 deletions.
4 changes: 3 additions & 1 deletion rust/routee-compass-core/src/model/state/state_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ pub enum StateError {
UnknownStateVariableName(String, String),
#[error("invalid state variable index {0}, should be in range [0, {1})")]
InvalidStateVariableIndex(usize, usize),
#[error("expected feature to have {0} unit type but found {1}")]
#[error("expected feature to have type '{0}' but found '{1}'")]
UnexpectedFeatureType(String, String),
#[error("expected feature unit to be {0} but found {1}")]
UnexpectedFeatureUnit(String, String),
#[error("{0}")]
BuildError(String),
Expand Down
109 changes: 95 additions & 14 deletions rust/routee-compass-core/src/model/state/state_feature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use serde::{Deserialize, Serialize};
/// field names. see link for more information:
/// https://serde.rs/enum-representations.html#untagged
/// ```
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq)]
#[serde(untagged)]
pub enum StateFeature {
Distance {
Expand All @@ -45,12 +45,71 @@ pub enum StateFeature {
initial: unit::Energy,
},
Custom {
name: String,
r#type: String,
unit: String,
format: CustomFeatureFormat,
},
}

impl PartialEq for StateFeature {
/// tests equality based on the feature type.
///
/// for distance|time|energy, it's fine to modify either the unit
/// or the initial value as this should not interfere with properly-
/// implemented TraversalModel, AccessModel, and FrontierModel instances.
///
/// for custom features, we are stricter about this equality test.
/// for instance, we cannot allow a user to change the "meaning" of a
/// state of charge value, that it is a floating point value in the range [0.0, 1.0].
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(
StateFeature::Distance {
distance_unit: _,
initial: _,
},
StateFeature::Distance {
distance_unit: _,
initial: _,
},
) => true,
(
StateFeature::Time {
time_unit: _,
initial: _,
},
StateFeature::Time {
time_unit: _,
initial: _,
},
) => true,
(
StateFeature::Energy {
energy_unit: _,
initial: _,
},
StateFeature::Energy {
energy_unit: _,
initial: _,
},
) => true,
(
StateFeature::Custom {
r#type: a_name,
unit: a_unit,
format: _,
},
StateFeature::Custom {
r#type: b_name,
unit: b_unit,
format: _,
},
) => a_name == b_name && a_unit == b_unit,
_ => false,
}
}
}

impl Display for StateFeature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -65,15 +124,19 @@ impl Display for StateFeature {
energy_unit,
initial,
} => write!(f, "unit: {}, initial: {}", energy_unit, initial),
StateFeature::Custom { name, unit, format } => {
StateFeature::Custom {
r#type: name,
unit,
format,
} => {
write!(f, "name: {} unit: {}, repr: {}", name, unit, format)
}
}
}
}

impl StateFeature {
pub fn get_feature_name(&self) -> String {
pub fn get_feature_type(&self) -> String {
match self {
StateFeature::Distance {
distance_unit: _,
Expand All @@ -88,10 +151,10 @@ impl StateFeature {
initial: _,
} => String::from("energy"),
StateFeature::Custom {
name,
r#type,
unit: _,
format: _,
} => name.clone(),
} => r#type.clone(),
}
}

Expand All @@ -110,7 +173,7 @@ impl StateFeature {
initial: _,
} => energy_unit.to_string(),
StateFeature::Custom {
name: _,
r#type: _,
unit,
format: _,
} => unit.clone(),
Expand All @@ -124,7 +187,7 @@ impl StateFeature {
pub fn get_feature_format(&self) -> CustomFeatureFormat {
match self {
StateFeature::Custom {
name: _,
r#type: _,
unit: _,
format,
} => *format,
Expand All @@ -133,7 +196,25 @@ impl StateFeature {
}

pub fn get_initial(&self) -> Result<StateVar, StateError> {
self.get_feature_format().initial()
match self {
StateFeature::Distance {
distance_unit: _,
initial,
} => Ok((*initial).into()),
StateFeature::Time {
time_unit: _,
initial,
} => Ok((*initial).into()),
StateFeature::Energy {
energy_unit: _,
initial,
} => Ok((*initial).into()),
StateFeature::Custom {
r#type: _,
unit: _,
format,
} => format.initial(),
}
}

pub fn get_distance_unit(&self) -> Result<unit::DistanceUnit, StateError> {
Expand All @@ -144,7 +225,7 @@ impl StateFeature {
} => Ok(*unit),
_ => Err(StateError::UnexpectedFeatureUnit(
String::from("distance"),
self.get_feature_name(),
self.get_feature_type(),
)),
}
}
Expand All @@ -157,7 +238,7 @@ impl StateFeature {
} => Ok(*unit),
_ => Err(StateError::UnexpectedFeatureUnit(
String::from("time"),
self.get_feature_name(),
self.get_feature_type(),
)),
}
}
Expand All @@ -170,21 +251,21 @@ impl StateFeature {
} => Ok(*energy_unit),
_ => Err(StateError::UnexpectedFeatureUnit(
String::from("energy"),
self.get_feature_name(),
self.get_feature_type(),
)),
}
}

pub fn get_custom_feature_format(&self) -> Result<&CustomFeatureFormat, StateError> {
match self {
StateFeature::Custom {
name: _,
r#type: _,
unit: _,
format,
} => Ok(format),
_ => Err(StateError::UnexpectedFeatureUnit(
self.get_feature_unit_name(),
self.get_feature_name(),
self.get_feature_type(),
)),
}
}
Expand Down
5 changes: 3 additions & 2 deletions rust/routee-compass-core/src/model/state/state_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ impl StateModel {
}

/// extends a state model by adding additional key/value pairs to the model mapping.
/// in the case of name collision, a warning is logged to the user and the newer
/// variable is used.
/// in the case of name collision, we compare old and new state features at that name.
/// if the state feature has the same unit (tested by StateFeature::Eq), then it can
/// overwrite the existing.
///
/// this method is used when state models are updated by the user query as Services
/// become Models in the SearchApp.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl VehicleType for BEV {
(
String::from(BEV::SOC_FEATURE_NAME),
StateFeature::Custom {
name: String::from("soc"),
r#type: String::from("soc"),
unit: String::from("percent"),
format: CustomFeatureFormat::FloatingPoint {
initial: initial_soc.into(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl VehicleType for PHEV {
(
String::from(PHEV::SOC_FEATURE_NAME),
StateFeature::Custom {
name: String::from("soc"),
r#type: String::from("soc"),
unit: String::from("percent"),
format: CustomFeatureFormat::FloatingPoint {
initial: initial_soc.into(),
Expand Down
1 change: 1 addition & 0 deletions rust/routee-compass/src/app/search/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod search_app;
pub mod search_app_graph_ops;
pub mod search_app_ops;
pub mod search_app_result;
9 changes: 5 additions & 4 deletions rust/routee-compass/src/app/search/search_app.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::search_app_result::SearchAppResult;
use super::{search_app_ops, search_app_result::SearchAppResult};
use crate::{
app::compass::{
compass_app_error::CompassAppError,
Expand Down Expand Up @@ -170,9 +170,10 @@ impl SearchApp {
let traversal_model = self.traversal_model_service.build(query)?;
let access_model = self.access_model_service.build(query)?;

let mut added_features = traversal_model.state_features();
added_features.extend(access_model.state_features());
let state_model = Arc::new(self.state_model.extend(added_features)?);
let state_features =
search_app_ops::collect_features(query, traversal_model.clone(), access_model.clone())?;
let state_model_instance = self.state_model.extend(state_features)?;
let state_model = Arc::new(state_model_instance);

let cost_model = self
.cost_model_service
Expand Down
57 changes: 57 additions & 0 deletions rust/routee-compass/src/app/search/search_app_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::{collections::HashMap, sync::Arc};

use itertools::Itertools;
use routee_compass_core::model::{
access::access_model::AccessModel,
state::{state_error::StateError, state_feature::StateFeature},
traversal::traversal_model::TraversalModel,
};

use crate::app::compass::config::config_json_extension::ConfigJsonExtensions;

/// collects the state features to use in this search. the features are collected in
/// the following order:
/// 1. from the traversal model
/// 2. from the access model
/// 3. optionally from the query itself
/// using the order above, each new source optionally overwrites any existing feature
/// by name (tuple index 0) as long as they match in StateFeature::get_feature_name and
/// StateFeature::get_feature_unit_name.
pub fn collect_features(
query: &serde_json::Value,
traversal_model: Arc<dyn TraversalModel>,
access_model: Arc<dyn AccessModel>,
) -> Result<Vec<(String, StateFeature)>, StateError> {
// prepare the set of features for this state model
let model_features = traversal_model
.state_features()
.into_iter()
.chain(access_model.state_features())
.collect::<HashMap<_, _>>();
// build the state model. inject state features from the traversal and access models
// and then allow the user to optionally override any initial conditions for those
// state features.
let user_features_option: Option<HashMap<String, StateFeature>> = query
.get_config_serde_optional(&"state_features", &"query")
.map_err(|e| StateError::BuildError(e.to_string()))?;
let user_features = user_features_option
.unwrap_or_default()
.into_iter()
.map(|(name, feature)| match model_features.get(&name) {
None => {
let fnames = model_features.keys().join(",");
Err(StateError::UnknownStateVariableName(name, fnames))
}
Some(existing) if existing.get_feature_type() != feature.get_feature_type() => {
Err(StateError::UnexpectedFeatureType(
existing.get_feature_type(),
feature.get_feature_type(),
))
}
Some(_) => Ok((name, feature)),
})
.collect::<Result<Vec<_>, _>>()?;
let mut added_features: Vec<(String, StateFeature)> = model_features.into_iter().collect_vec();
added_features.extend(user_features);
Ok(added_features)
}

0 comments on commit 29d944b

Please sign in to comment.