diff --git a/rust/routee-compass-core/src/model/state/state_error.rs b/rust/routee-compass-core/src/model/state/state_error.rs index 3540fae6..8f39c607 100644 --- a/rust/routee-compass-core/src/model/state/state_error.rs +++ b/rust/routee-compass-core/src/model/state/state_error.rs @@ -12,7 +12,9 @@ pub enum StateError { UnknownStateVariableName(String, String), #[error("invalid state variable index {0}, should be in range [0, {1})")] InvalidStateVariableIndex(usize, usize), - #[error("expected feature to have {0} unit type but found {1}")] + #[error("expected feature to have type '{0}' but found '{1}'")] + UnexpectedFeatureType(String, String), + #[error("expected feature unit to be {0} but found {1}")] UnexpectedFeatureUnit(String, String), #[error("{0}")] BuildError(String), diff --git a/rust/routee-compass-core/src/model/state/state_feature.rs b/rust/routee-compass-core/src/model/state/state_feature.rs index 3ba281cf..dc3f5edc 100644 --- a/rust/routee-compass-core/src/model/state/state_feature.rs +++ b/rust/routee-compass-core/src/model/state/state_feature.rs @@ -29,7 +29,7 @@ use serde::{Deserialize, Serialize}; /// field names. see link for more information: /// https://serde.rs/enum-representations.html#untagged /// ``` -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Clone, Debug, Eq)] #[serde(untagged)] pub enum StateFeature { Distance { @@ -45,12 +45,71 @@ pub enum StateFeature { initial: unit::Energy, }, Custom { - name: String, + r#type: String, unit: String, format: CustomFeatureFormat, }, } +impl PartialEq for StateFeature { + /// tests equality based on the feature type. + /// + /// for distance|time|energy, it's fine to modify either the unit + /// or the initial value as this should not interfere with properly- + /// implemented TraversalModel, AccessModel, and FrontierModel instances. + /// + /// for custom features, we are stricter about this equality test. + /// for instance, we cannot allow a user to change the "meaning" of a + /// state of charge value, that it is a floating point value in the range [0.0, 1.0]. + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + StateFeature::Distance { + distance_unit: _, + initial: _, + }, + StateFeature::Distance { + distance_unit: _, + initial: _, + }, + ) => true, + ( + StateFeature::Time { + time_unit: _, + initial: _, + }, + StateFeature::Time { + time_unit: _, + initial: _, + }, + ) => true, + ( + StateFeature::Energy { + energy_unit: _, + initial: _, + }, + StateFeature::Energy { + energy_unit: _, + initial: _, + }, + ) => true, + ( + StateFeature::Custom { + r#type: a_name, + unit: a_unit, + format: _, + }, + StateFeature::Custom { + r#type: b_name, + unit: b_unit, + format: _, + }, + ) => a_name == b_name && a_unit == b_unit, + _ => false, + } + } +} + impl Display for StateFeature { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -65,7 +124,11 @@ impl Display for StateFeature { energy_unit, initial, } => write!(f, "unit: {}, initial: {}", energy_unit, initial), - StateFeature::Custom { name, unit, format } => { + StateFeature::Custom { + r#type: name, + unit, + format, + } => { write!(f, "name: {} unit: {}, repr: {}", name, unit, format) } } @@ -73,7 +136,7 @@ impl Display for StateFeature { } impl StateFeature { - pub fn get_feature_name(&self) -> String { + pub fn get_feature_type(&self) -> String { match self { StateFeature::Distance { distance_unit: _, @@ -88,10 +151,10 @@ impl StateFeature { initial: _, } => String::from("energy"), StateFeature::Custom { - name, + r#type, unit: _, format: _, - } => name.clone(), + } => r#type.clone(), } } @@ -110,7 +173,7 @@ impl StateFeature { initial: _, } => energy_unit.to_string(), StateFeature::Custom { - name: _, + r#type: _, unit, format: _, } => unit.clone(), @@ -124,7 +187,7 @@ impl StateFeature { pub fn get_feature_format(&self) -> CustomFeatureFormat { match self { StateFeature::Custom { - name: _, + r#type: _, unit: _, format, } => *format, @@ -133,7 +196,25 @@ impl StateFeature { } pub fn get_initial(&self) -> Result { - self.get_feature_format().initial() + match self { + StateFeature::Distance { + distance_unit: _, + initial, + } => Ok((*initial).into()), + StateFeature::Time { + time_unit: _, + initial, + } => Ok((*initial).into()), + StateFeature::Energy { + energy_unit: _, + initial, + } => Ok((*initial).into()), + StateFeature::Custom { + r#type: _, + unit: _, + format, + } => format.initial(), + } } pub fn get_distance_unit(&self) -> Result { @@ -144,7 +225,7 @@ impl StateFeature { } => Ok(*unit), _ => Err(StateError::UnexpectedFeatureUnit( String::from("distance"), - self.get_feature_name(), + self.get_feature_type(), )), } } @@ -157,7 +238,7 @@ impl StateFeature { } => Ok(*unit), _ => Err(StateError::UnexpectedFeatureUnit( String::from("time"), - self.get_feature_name(), + self.get_feature_type(), )), } } @@ -170,7 +251,7 @@ impl StateFeature { } => Ok(*energy_unit), _ => Err(StateError::UnexpectedFeatureUnit( String::from("energy"), - self.get_feature_name(), + self.get_feature_type(), )), } } @@ -178,13 +259,13 @@ impl StateFeature { pub fn get_custom_feature_format(&self) -> Result<&CustomFeatureFormat, StateError> { match self { StateFeature::Custom { - name: _, + r#type: _, unit: _, format, } => Ok(format), _ => Err(StateError::UnexpectedFeatureUnit( self.get_feature_unit_name(), - self.get_feature_name(), + self.get_feature_type(), )), } } diff --git a/rust/routee-compass-core/src/model/state/state_model.rs b/rust/routee-compass-core/src/model/state/state_model.rs index 50de946c..7c3bd9a3 100644 --- a/rust/routee-compass-core/src/model/state/state_model.rs +++ b/rust/routee-compass-core/src/model/state/state_model.rs @@ -35,8 +35,9 @@ impl StateModel { } /// extends a state model by adding additional key/value pairs to the model mapping. - /// in the case of name collision, a warning is logged to the user and the newer - /// variable is used. + /// in the case of name collision, we compare old and new state features at that name. + /// if the state feature has the same unit (tested by StateFeature::Eq), then it can + /// overwrite the existing. /// /// this method is used when state models are updated by the user query as Services /// become Models in the SearchApp. diff --git a/rust/routee-compass-powertrain/src/routee/vehicle/default/bev.rs b/rust/routee-compass-powertrain/src/routee/vehicle/default/bev.rs index ee46e2fa..781c6d15 100644 --- a/rust/routee-compass-powertrain/src/routee/vehicle/default/bev.rs +++ b/rust/routee-compass-powertrain/src/routee/vehicle/default/bev.rs @@ -63,7 +63,7 @@ impl VehicleType for BEV { ( String::from(BEV::SOC_FEATURE_NAME), StateFeature::Custom { - name: String::from("soc"), + r#type: String::from("soc"), unit: String::from("percent"), format: CustomFeatureFormat::FloatingPoint { initial: initial_soc.into(), diff --git a/rust/routee-compass-powertrain/src/routee/vehicle/default/phev.rs b/rust/routee-compass-powertrain/src/routee/vehicle/default/phev.rs index 9db7aa07..5c877c9c 100644 --- a/rust/routee-compass-powertrain/src/routee/vehicle/default/phev.rs +++ b/rust/routee-compass-powertrain/src/routee/vehicle/default/phev.rs @@ -74,7 +74,7 @@ impl VehicleType for PHEV { ( String::from(PHEV::SOC_FEATURE_NAME), StateFeature::Custom { - name: String::from("soc"), + r#type: String::from("soc"), unit: String::from("percent"), format: CustomFeatureFormat::FloatingPoint { initial: initial_soc.into(), diff --git a/rust/routee-compass/src/app/search/mod.rs b/rust/routee-compass/src/app/search/mod.rs index e2b62628..ee7ec11f 100644 --- a/rust/routee-compass/src/app/search/mod.rs +++ b/rust/routee-compass/src/app/search/mod.rs @@ -1,3 +1,4 @@ pub mod search_app; pub mod search_app_graph_ops; +pub mod search_app_ops; pub mod search_app_result; diff --git a/rust/routee-compass/src/app/search/search_app.rs b/rust/routee-compass/src/app/search/search_app.rs index 76a19f77..53f52684 100644 --- a/rust/routee-compass/src/app/search/search_app.rs +++ b/rust/routee-compass/src/app/search/search_app.rs @@ -1,4 +1,4 @@ -use super::search_app_result::SearchAppResult; +use super::{search_app_ops, search_app_result::SearchAppResult}; use crate::{ app::compass::{ compass_app_error::CompassAppError, @@ -170,9 +170,10 @@ impl SearchApp { let traversal_model = self.traversal_model_service.build(query)?; let access_model = self.access_model_service.build(query)?; - let mut added_features = traversal_model.state_features(); - added_features.extend(access_model.state_features()); - let state_model = Arc::new(self.state_model.extend(added_features)?); + let state_features = + search_app_ops::collect_features(query, traversal_model.clone(), access_model.clone())?; + let state_model_instance = self.state_model.extend(state_features)?; + let state_model = Arc::new(state_model_instance); let cost_model = self .cost_model_service diff --git a/rust/routee-compass/src/app/search/search_app_ops.rs b/rust/routee-compass/src/app/search/search_app_ops.rs new file mode 100644 index 00000000..d92821f1 --- /dev/null +++ b/rust/routee-compass/src/app/search/search_app_ops.rs @@ -0,0 +1,57 @@ +use std::{collections::HashMap, sync::Arc}; + +use itertools::Itertools; +use routee_compass_core::model::{ + access::access_model::AccessModel, + state::{state_error::StateError, state_feature::StateFeature}, + traversal::traversal_model::TraversalModel, +}; + +use crate::app::compass::config::config_json_extension::ConfigJsonExtensions; + +/// collects the state features to use in this search. the features are collected in +/// the following order: +/// 1. from the traversal model +/// 2. from the access model +/// 3. optionally from the query itself +/// using the order above, each new source optionally overwrites any existing feature +/// by name (tuple index 0) as long as they match in StateFeature::get_feature_name and +/// StateFeature::get_feature_unit_name. +pub fn collect_features( + query: &serde_json::Value, + traversal_model: Arc, + access_model: Arc, +) -> Result, StateError> { + // prepare the set of features for this state model + let model_features = traversal_model + .state_features() + .into_iter() + .chain(access_model.state_features()) + .collect::>(); + // build the state model. inject state features from the traversal and access models + // and then allow the user to optionally override any initial conditions for those + // state features. + let user_features_option: Option> = query + .get_config_serde_optional(&"state_features", &"query") + .map_err(|e| StateError::BuildError(e.to_string()))?; + let user_features = user_features_option + .unwrap_or_default() + .into_iter() + .map(|(name, feature)| match model_features.get(&name) { + None => { + let fnames = model_features.keys().join(","); + Err(StateError::UnknownStateVariableName(name, fnames)) + } + Some(existing) if existing.get_feature_type() != feature.get_feature_type() => { + Err(StateError::UnexpectedFeatureType( + existing.get_feature_type(), + feature.get_feature_type(), + )) + } + Some(_) => Ok((name, feature)), + }) + .collect::, _>>()?; + let mut added_features: Vec<(String, StateFeature)> = model_features.into_iter().collect_vec(); + added_features.extend(user_features); + Ok(added_features) +}