From 396a71cc0988c85d39553d35df448b20e2fbafa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Willy=20Rom=C3=A3o?= Date: Sun, 21 Jan 2024 00:16:31 +0000 Subject: [PATCH] refactor(context): reimplement context as an kv store --- mfm_core/src/contexts/mod.rs | 15 +- mfm_core/src/states/mod.rs | 31 ++- mfm_machine/src/state/context.rs | 86 ++++---- mfm_machine/src/state_machine/mod.rs | 150 ++++++++------ mfm_machine/src/state_machine/tracker.rs | 91 ++++----- mfm_machine/tests/default_impls.rs | 192 +++++++++--------- mfm_machine/tests/n_states_with_n_ctxs.rs | 18 +- mfm_machine/tests/public_api_test.rs | 26 ++- .../tests/retry_workflow_state_machine.rs | 13 +- 9 files changed, 335 insertions(+), 287 deletions(-) diff --git a/mfm_core/src/contexts/mod.rs b/mfm_core/src/contexts/mod.rs index f825b65..318cecd 100644 --- a/mfm_core/src/contexts/mod.rs +++ b/mfm_core/src/contexts/mod.rs @@ -1,8 +1,5 @@ use crate::config::Config; -use anyhow::{anyhow, Error}; -use mfm_machine::state::context::Context; use serde_derive::{Deserialize, Serialize}; -use serde_json::Value; #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] @@ -10,19 +7,9 @@ pub enum ConfigSource { TomlFile(String), } -impl Context for ConfigSource { - fn read(&self) -> Result { - serde_json::to_value(self).map_err(|e| anyhow!(e)) - } - - fn write(&mut self, _: &Value) -> Result<(), Error> { - // do nothing - Ok(()) - } -} - #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] pub struct ReadConfig { pub config_source: ConfigSource, pub config: Config, } +pub const READ_CONFIG: &str = "read_config"; diff --git a/mfm_core/src/states/mod.rs b/mfm_core/src/states/mod.rs index 43f6ae7..35767b2 100644 --- a/mfm_core/src/states/mod.rs +++ b/mfm_core/src/states/mod.rs @@ -5,7 +5,7 @@ use mfm_machine::state::{ Tag, }; -use crate::contexts; +use crate::contexts::{self, READ_CONFIG}; #[derive(Debug, Clone, PartialEq, StateMetadataReqs)] pub struct ReadConfig { @@ -16,7 +16,7 @@ pub struct ReadConfig { } impl ReadConfig { - fn new() -> Self { + pub fn new() -> Self { Self { label: Label::new("read_config").unwrap(), tags: vec![Tag::new("setup").unwrap()], @@ -25,10 +25,20 @@ impl ReadConfig { } } } +impl Default for ReadConfig { + fn default() -> Self { + Self::new() + } +} impl StateHandler for ReadConfig { fn handler(&self, context: ContextWrapper) -> StateResult { - let value = context.lock().unwrap().read().unwrap(); + let value = context + .lock() + .unwrap() + .read(READ_CONFIG.to_string()) + .unwrap(); + let data: contexts::ConfigSource = serde_json::from_value(value).unwrap(); println!("data: {:?}", data); Ok(()) @@ -37,16 +47,25 @@ impl StateHandler for ReadConfig { #[cfg(test)] mod test { - use mfm_machine::state::{context::wrap_context, StateHandler}; + use std::collections::HashMap; + + use mfm_machine::state::{ + context::{wrap_context, Local}, + StateHandler, + }; + use serde_json::json; - use crate::contexts::ConfigSource; + use crate::contexts::{ConfigSource, READ_CONFIG}; use super::ReadConfig; #[test] fn test_readconfig_from_source_file() { let state = ReadConfig::new(); - let ctx_input = wrap_context(ConfigSource::File("test_config.toml".to_string())); + let ctx_input = wrap_context(Local::new(HashMap::from([( + READ_CONFIG.to_string(), + json!(ConfigSource::TomlFile("test_config.toml".to_string())), + )]))); let result = state.handler(ctx_input); assert!(result.is_ok()) } diff --git a/mfm_machine/src/state/context.rs b/mfm_machine/src/state/context.rs index 1d5aac7..d8e9709 100644 --- a/mfm_machine/src/state/context.rs +++ b/mfm_machine/src/state/context.rs @@ -1,13 +1,51 @@ -use std::sync::{Arc, Mutex}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; -use anyhow::Error; -use serde_json::Value; +use anyhow::{anyhow, Error}; +use serde_derive::{Deserialize, Serialize}; +use serde_json::{json, Value}; pub type ContextWrapper = Arc>>; +// TODO: rethink this implementation of kv store context; +// we should be able to express context constraints for each state +// at the type system level; +#[derive(Default, Serialize, Deserialize)] +pub struct Local { + map: HashMap, +} + +impl Local { + pub fn new(map: HashMap) -> Self { + Self { map } + } +} + +impl Context for Local { + fn read(&self, key: String) -> Result { + Ok(self + .map + .get(&key) + .ok_or_else(|| anyhow!("key not found"))? + .clone()) + } + + fn write(&mut self, key: String, value: &Value) -> Result<(), Error> { + self.map.insert(key, value.clone()); + Ok(()) + } + + fn dump(&self) -> Result { + Ok(json!(self)) + } +} + pub trait Context { - fn read(&self) -> Result; - fn write(&mut self, value: &Value) -> Result<(), Error>; + fn read(&self, key: String) -> Result; + fn write(&mut self, key: String, value: &Value) -> Result<(), Error>; + fn dump(&self) -> Result; } pub fn wrap_context(context: C) -> ContextWrapper { @@ -17,45 +55,19 @@ pub fn wrap_context(context: C) -> ContextWrapper { #[cfg(test)] mod test { - use anyhow::anyhow; - use serde_derive::{Deserialize, Serialize}; + use serde_json::json; use super::*; - #[derive(Serialize, Deserialize)] - struct ContextA { - a: String, - b: u64, - } - - impl ContextA { - fn _new(a: String, b: u64) -> Self { - Self { a, b } - } - } - - impl Context for ContextA { - fn read(&self) -> Result { - serde_json::to_value(self).map_err(|e| anyhow!(e)) - } - - fn write(&mut self, value: &Value) -> Result<(), Error> { - let ctx: ContextA = serde_json::from_value(value.clone()).map_err(|e| anyhow!(e))?; - self.a = ctx.a; - self.b = ctx.b; - Ok(()) - } - } - #[test] fn test_read_write() { - let context_a: &mut dyn Context = &mut ContextA::_new(String::from("hello"), 7); - let context_b: &dyn Context = &ContextA::_new(String::from("hellow"), 9); + let context_a: &mut dyn Context = &mut Local::default(); - assert_ne!(context_a.read().unwrap(), context_b.read().unwrap()); + let body = json!({"b1": "test1"}); + let key = "key1".to_string(); - context_a.write(&context_b.read().unwrap()).unwrap(); + context_a.write(key.clone(), &body).unwrap(); - assert_eq!(context_a.read().unwrap(), context_b.read().unwrap()); + assert_eq!(context_a.read(key).unwrap(), body); } } diff --git a/mfm_machine/src/state_machine/mod.rs b/mfm_machine/src/state_machine/mod.rs index e010292..30fdf4c 100644 --- a/mfm_machine/src/state_machine/mod.rs +++ b/mfm_machine/src/state_machine/mod.rs @@ -11,13 +11,22 @@ pub mod tracker; pub struct StateMachineBuilder { pub states: States, pub tracker: Option>, + pub max_recoveries: usize, +} + +pub const MAX_RECOVERIES_MULT: usize = 3; + +// default_max_recoveries is number of states * MAX_RECOVERIES_MULT + 1 +pub fn default_max_recoveries(states: States) -> usize { + states.len() * MAX_RECOVERIES_MULT + 1 } impl StateMachineBuilder { pub fn new(states: States) -> Self { Self { - states, + states: states.clone(), tracker: None, + max_recoveries: default_max_recoveries(states), } } @@ -26,12 +35,18 @@ impl StateMachineBuilder { self } + pub fn max_recoveries(mut self, max: usize) -> Self { + self.max_recoveries = max; + self + } + pub fn build(self) -> StateMachine { StateMachine { states: self.states, tracker: self .tracker .unwrap_or_else(|| Box::new(HashMapTracker::new())), + max_recoveries: self.max_recoveries, } } } @@ -39,10 +54,12 @@ impl StateMachineBuilder { pub struct StateMachine { pub states: States, pub tracker: Box, + max_recoveries: usize, } #[derive(Debug)] pub enum StateMachineError { + ReachedMaxRecoveries((), anyhow::Error), EmptyState((), anyhow::Error), InternalError(StateResult, anyhow::Error), StateError(StateResult, anyhow::Error), @@ -51,8 +68,9 @@ pub enum StateMachineError { impl StateMachine { pub fn new(states: States) -> Self { Self { - states, + states: states.clone(), tracker: Box::new(HashMapTracker::new()), + max_recoveries: default_max_recoveries(states), } } @@ -64,6 +82,11 @@ impl StateMachine { self.states.len() > state_index } + fn reached_max_recoveries(&self) -> (bool, usize) { + let steps = self.track_history().len(); + (steps >= self.max_recoveries, steps) + } + // TODO: add logging, instrumentation fn transition( &mut self, @@ -78,6 +101,13 @@ impl StateMachine { )); } + if let (true, steps) = self.reached_max_recoveries() { + return Err(StateMachineError::ReachedMaxRecoveries( + (), + anyhow!("reached max recoveries ({})", steps), + )); + } + let state = &self.states[state_index]; // if thats true, means that no state was executed before and this is the first one @@ -87,15 +117,15 @@ impl StateMachine { let state_result = last_state_result.unwrap(); - let value = context.lock().unwrap().read().unwrap(); - - println!( - "state: {:?}; state_index: {}; state_result: {:?}; value: {:?}", - state.label(), - state_index, - state_result, - value, - ); + // //FIXME: state_machine.track_history() show be enough + // let value = context.lock().unwrap().dump().unwrap(); + // println!( + // "state: {:?}; state_index: {}; state_result: {:?}; value: {:?}", + // state.label(), + // state_index, + // state_result, + // value, + // ); // TODO: it may be the transition match state_result { @@ -147,7 +177,7 @@ impl StateMachine { let state = &self.states[next_state_index]; let result = state.handler(context.clone()); - self.tracker.as_mut().track( + let _ = self.tracker.as_mut().track( Index::new(next_state_index, state.label(), state.tags()), context.clone(), ); @@ -164,44 +194,15 @@ impl StateMachine { mod test { use std::sync::Arc; - use crate::state::context::{wrap_context, ContextWrapper}; - use crate::state::{ - context::Context, DependencyStrategy, Label, StateHandler, StateMetadata, Tag, - }; + use crate::state::context::{wrap_context, Context, ContextWrapper, Local}; + use crate::state::{DependencyStrategy, Label, StateHandler, StateMetadata, Tag}; use crate::state::{StateError, StateErrorRecoverability}; - use anyhow::anyhow; - use anyhow::Error; use mfm_machine_derive::StateMetadataReqs; use serde_derive::{Deserialize, Serialize}; - use serde_json::{json, Value}; + use serde_json::json; use super::StateResult; - #[derive(Serialize, Deserialize)] - struct ContextA { - a: String, - b: u64, - } - - impl ContextA { - fn new(a: String, b: u64) -> Self { - Self { a, b } - } - } - - impl Context for ContextA { - fn read(&self) -> Result { - serde_json::to_value(self).map_err(|e| anyhow!(e)) - } - - fn write(&mut self, value: &Value) -> Result<(), Error> { - let ctx: ContextA = serde_json::from_value(value.clone()).map_err(|e| anyhow!(e))?; - self.a = ctx.a; - self.b = ctx.b; - Ok(()) - } - } - #[derive(Debug, Clone, PartialEq, StateMetadataReqs)] pub struct Setup { label: Label, @@ -221,12 +222,24 @@ mod test { } } + #[derive(Serialize, Deserialize)] + struct SetupCtx { + a: String, + b: u32, + } + impl StateHandler for Setup { fn handler(&self, context: ContextWrapper) -> StateResult { - let value = context.lock().unwrap().read().unwrap(); - let _data: ContextA = serde_json::from_value(value).unwrap(); - let data = json!({ "a": "setting up", "b": 1 }); - match context.lock().as_mut().unwrap().write(&data) { + let data = SetupCtx { + a: "setup_b".to_string(), + b: 1, + }; + match context + .lock() + .as_mut() + .unwrap() + .write("setup".to_string(), &json!(data)) + { Ok(()) => Ok(()), Err(e) => Err(StateError::StorageAccess( StateErrorRecoverability::Recoverable, @@ -236,6 +249,12 @@ mod test { } } + #[derive(Serialize, Deserialize)] + pub struct ReportCtx { + pub report_msg: String, + pub report_value: u32, + } + #[derive(Debug, Clone, PartialEq, StateMetadataReqs)] pub struct Report { label: Label, @@ -258,10 +277,19 @@ mod test { impl StateHandler for Report { fn handler(&self, context: ContextWrapper) -> StateResult { - let value = context.lock().unwrap().read().unwrap(); - let _data: ContextA = serde_json::from_value(value).unwrap(); - let data = json!({ "a": "some new data reported", "b": 7 }); - match context.lock().as_mut().unwrap().write(&data) { + let setup_ctx: SetupCtx = + serde_json::from_value(context.lock().unwrap().read("setup".to_string()).unwrap()) + .unwrap(); + let data = json!(ReportCtx { + report_msg: format!("{}: {}", "some new data reported", setup_ctx.a), + report_value: setup_ctx.b + }); + match context + .lock() + .as_mut() + .unwrap() + .write("report".to_string(), &data) + { Ok(()) => Ok(()), Err(e) => Err(StateError::StorageAccess( StateErrorRecoverability::Recoverable, @@ -276,7 +304,7 @@ mod test { let label = Label::new("setup_state").unwrap(); let tags = vec![Tag::new("setup").unwrap()]; let state = Setup::new(); - let ctx_input = wrap_context(ContextA::new(String::from("hello"), 7)); + let ctx_input = wrap_context(Local::default()); let result = state.handler(ctx_input); @@ -309,9 +337,9 @@ mod test { let mut state_machine = StateMachine::new(initial_states); - let context = wrap_context(ContextA::new(String::from("hello"), 7)); + let context = wrap_context(Local::default()); let result = state_machine.execute(context.clone()); - let last_ctx_message = context.lock().unwrap().read().unwrap(); + let last_ctx_message = context.lock().unwrap().dump().unwrap(); assert_eq!(state_machine.states.len(), iss.len()); @@ -324,10 +352,18 @@ mod test { }, ); + let last_ctx_data: Local = serde_json::from_value(last_ctx_message).unwrap(); + let report_ctx: ReportCtx = + serde_json::from_value(last_ctx_data.read("report".to_string()).unwrap()).unwrap(); + + println!("report_msg: {}", report_ctx.report_msg); + assert!(result.is_ok()); assert_eq!( - last_ctx_message, - json!({"a": "some new data reported", "b": 7}) + report_ctx.report_msg, + String::from("some new data reported: setup_b") ); + + assert_eq!(report_ctx.report_value, 1); } } diff --git a/mfm_machine/src/state_machine/tracker.rs b/mfm_machine/src/state_machine/tracker.rs index b78ba24..0812323 100644 --- a/mfm_machine/src/state_machine/tracker.rs +++ b/mfm_machine/src/state_machine/tracker.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, fmt::Debug}; use anyhow::{anyhow, Error}; +use serde_json::Value; use crate::state::{context::ContextWrapper, Label, Tag}; @@ -10,41 +11,43 @@ pub trait TrackerMetadata { fn history(&self) -> TrackerHistory; } -#[derive(Clone)] -pub struct TrackerHistory(Vec<(usize, Index, ContextWrapper)>); +#[derive(Default, Clone)] +pub struct TrackerHistory(Vec<(usize, Index, Value)>); impl Debug for TrackerHistory { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // TODO: add a way to see the context at tracker history - self.0 - .iter() - .map(|(history_id, index, _)| { - writeln!( - f, - "history_id ({}); index ({:?}); context (ptr)", - history_id, index - ) - }) - .collect() + self.0.iter().try_for_each(|(history_id, index, value)| { + writeln!( + f, + "history_id ({}); index ({:?}); context ({:?})", + history_id, index, value + ) + }) } } impl TrackerHistory { - pub fn new() -> Self { - TrackerHistory(Vec::new()) + pub fn new(v: Vec<(usize, Index, Value)>) -> Self { + TrackerHistory(v) } pub fn len(&self) -> usize { self.0.len() } + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + pub fn push(&mut self, index: Index, context: ContextWrapper) { - self.0.push((self.len(), index, context)) + self.0 + .push((self.len(), index, context.lock().unwrap().dump().unwrap())) } } impl IntoIterator for TrackerHistory { - type Item = (usize, Index, ContextWrapper); + type Item = (usize, Index, Value); type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { @@ -84,7 +87,7 @@ impl HashMapTracker { pub fn new() -> Self { Self { tracker: HashMap::new(), - history: TrackerHistory::new(), + history: TrackerHistory::default(), } } } @@ -131,49 +134,25 @@ impl TrackerMetadata for HashMapTracker { #[cfg(test)] mod test { - use anyhow::{anyhow, Error}; - use serde_derive::{Deserialize, Serialize}; - use serde_json::Value; + use std::collections::HashMap; + + use serde_json::json; use crate::state::{ - context::{wrap_context, Context, ContextWrapper}, + context::{wrap_context, ContextWrapper, Local}, Label, Tag, }; use super::{HashMapTracker, Index, Tracker}; - #[derive(Serialize, Deserialize)] - struct ContextA { - a: String, - b: u64, - } - - impl ContextA { - fn new(a: String, b: u64) -> Self { - Self { a, b } - } - } - - impl Context for ContextA { - fn read(&self) -> Result { - serde_json::to_value(self).map_err(|e| anyhow!(e)) - } - - fn write(&mut self, value: &Value) -> Result<(), Error> { - let ctx: ContextA = serde_json::from_value(value.clone()).map_err(|e| anyhow!(e))?; - self.a = ctx.a; - self.b = ctx.b; - Ok(()) - } - } #[test] fn test_tracker() { let tracker: &mut dyn Tracker = &mut HashMapTracker::new(); let contexts: Vec = vec![ - wrap_context(ContextA::new("value".to_string(), 1)), - wrap_context(ContextA::new("value".to_string(), 2)), - wrap_context(ContextA::new("value".to_string(), 3)), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(1))]))), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(2))]))), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(3))]))), ]; let indexes = vec![ Index::new( @@ -202,8 +181,8 @@ mod test { for i in 0..indexes.len() { let context_recovered = tracker.recover(indexes[i].clone()).unwrap(); - let value_recovered = context_recovered.lock().unwrap().read().unwrap(); - let value_expected = contexts[i].lock().unwrap().read().unwrap(); + let value_recovered = context_recovered.lock().unwrap().dump().unwrap(); + let value_expected = contexts[i].lock().unwrap().dump().unwrap(); assert_eq!(value_expected, value_recovered); } @@ -214,9 +193,9 @@ mod test { let tracker: &mut dyn Tracker = &mut HashMapTracker::new(); let contexts: Vec = vec![ - wrap_context(ContextA::new("value".to_string(), 1)), - wrap_context(ContextA::new("value".to_string(), 2)), - wrap_context(ContextA::new("value".to_string(), 3)), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(1))]))), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(2))]))), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(3))]))), ]; let indexes = vec![ Index::new( @@ -256,9 +235,9 @@ mod test { let tracker: &mut dyn Tracker = &mut HashMapTracker::new(); let contexts: Vec = vec![ - wrap_context(ContextA::new("value".to_string(), 1)), - wrap_context(ContextA::new("value".to_string(), 2)), - wrap_context(ContextA::new("value".to_string(), 3)), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(1))]))), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(2))]))), + wrap_context(Local::new(HashMap::from([("value".to_string(), json!(3))]))), ]; let indexes = vec![ Index::new( diff --git a/mfm_machine/tests/default_impls.rs b/mfm_machine/tests/default_impls.rs index 6b13429..a3e85c3 100644 --- a/mfm_machine/tests/default_impls.rs +++ b/mfm_machine/tests/default_impls.rs @@ -1,6 +1,4 @@ use anyhow::anyhow; -use anyhow::Error; -use mfm_machine::state::context::Context; use mfm_machine::state::context::ContextWrapper; use mfm_machine::state::DependencyStrategy; use mfm_machine::state::Label; @@ -13,32 +11,7 @@ use mfm_machine::state::Tag; use mfm_machine_derive::StateMetadataReqs; use rand::Rng; use serde_derive::{Deserialize, Serialize}; -use serde_json::{json, Value}; - -#[derive(Serialize, Deserialize)] -pub struct ContextA { - pub a: String, - pub b: u64, -} - -impl ContextA { - pub fn new(a: String, b: u64) -> Self { - Self { a, b } - } -} - -impl Context for ContextA { - fn read(&self) -> Result { - serde_json::to_value(self).map_err(|e| anyhow!(e)) - } - - fn write(&mut self, value: &Value) -> Result<(), Error> { - let ctx: ContextA = serde_json::from_value(value.clone()).map_err(|e| anyhow!(e))?; - self.a = ctx.a; - self.b = ctx.b; - Ok(()) - } -} +use serde_json::json; #[derive(Debug, Clone, PartialEq, StateMetadataReqs)] pub struct Setup { @@ -48,6 +21,12 @@ pub struct Setup { depends_on_strategy: DependencyStrategy, } +#[derive(Serialize, Deserialize)] +pub struct SetupCtx { + a: String, + b: u32, +} + impl Default for Setup { fn default() -> Self { Self::new() @@ -66,13 +45,18 @@ impl Setup { impl StateHandler for Setup { fn handler(&self, context: ContextWrapper) -> StateResult { - let value = context.lock().unwrap().read().unwrap(); - let _data: ContextA = serde_json::from_value(value).unwrap(); - let mut rng = rand::thread_rng(); - let data = json!({ "a": "setting up", "b": rng.gen_range(0..9) }); + let data = SetupCtx { + a: "setup_b".to_string(), + b: rng.gen_range(0..9), + }; - match context.lock().as_mut().unwrap().write(&data) { + match context + .lock() + .as_mut() + .unwrap() + .write("setup".to_string(), &json!(data)) + { Ok(()) => Ok(()), Err(e) => Err(StateError::StorageAccess( StateErrorRecoverability::Recoverable, @@ -82,6 +66,12 @@ impl StateHandler for Setup { } } +#[derive(Serialize, Deserialize)] +pub struct ComputePriceCtx { + msg: String, + b: u32, +} + #[derive(Debug, Clone, PartialEq, StateMetadataReqs)] pub struct ComputePrice { label: Label, @@ -108,8 +98,8 @@ impl ComputePrice { impl StateHandler for ComputePrice { fn handler(&self, context: ContextWrapper) -> StateResult { - let value = context.lock().unwrap().read().unwrap(); - let _data: ContextA = serde_json::from_value(value).unwrap(); + let value = context.lock().unwrap().read("setup".to_string()).unwrap(); + let _data: SetupCtx = serde_json::from_value(value).unwrap(); if _data.b % 2 == 0 { return Err(StateError::ParsingInput( StateErrorRecoverability::Recoverable, @@ -117,8 +107,16 @@ impl StateHandler for ComputePrice { )); } - let data = json!({ "a": "the input number is odd", "b": _data.b }); - match context.lock().as_mut().unwrap().write(&data) { + let data = ComputePriceCtx { + msg: "the input number is odd".to_string(), + b: _data.b, + }; + match context + .lock() + .as_mut() + .unwrap() + .write("compute".to_string(), &json!(data)) + { Ok(()) => Ok(()), Err(e) => Err(StateError::StorageAccess( StateErrorRecoverability::Unrecoverable, @@ -128,6 +126,12 @@ impl StateHandler for ComputePrice { } } +#[derive(Serialize, Deserialize)] +pub struct ReportCtx { + pub report_msg: String, + pub report_value: u32, +} + #[derive(Debug, Clone, PartialEq, StateMetadataReqs)] pub struct Report { label: Label, @@ -154,11 +158,35 @@ impl Report { impl StateHandler for Report { fn handler(&self, context: ContextWrapper) -> StateResult { - let value = context.lock().unwrap().read().unwrap(); - let _data: ContextA = serde_json::from_value(value).unwrap(); - let data = - json!({ "a": format!("{}: {}", "some new data reported", _data.a), "b": _data.b }); - match context.lock().as_mut().unwrap().write(&data) { + let value = { + let compute_ctx = context.lock().unwrap().read("compute".to_string()); + if let Ok(value) = compute_ctx { + value + } else { + context.lock().unwrap().read("setup".to_string()).unwrap() + } + }; + + let data = match serde_json::from_value::(value.clone()) { + Ok(computer_ctx) => json!(ReportCtx { + report_msg: format!("{}: {}", "some new data reported", computer_ctx.msg), + report_value: computer_ctx.b + }), + Err(_) => { + let setup_ctx: SetupCtx = serde_json::from_value(value).unwrap(); + json!(ReportCtx { + report_msg: format!("{}: {}", "some new data reported", setup_ctx.a), + report_value: setup_ctx.b + }) + } + }; + + match context + .lock() + .as_mut() + .unwrap() + .write("report".to_string(), &data) + { Ok(()) => Ok(()), Err(e) => Err(StateError::StorageAccess( StateErrorRecoverability::Recoverable, @@ -181,28 +209,6 @@ pub struct ConfigStateCtx { pub c: String, } -impl ConfigStateCtx { - pub fn new(config: Config) -> Self { - Self { - config, - c: String::new(), - } - } -} - -impl Context for ConfigStateCtx { - fn read(&self) -> Result { - serde_json::to_value(self).map_err(|e| anyhow!(e)) - } - - fn write(&mut self, value: &Value) -> Result<(), Error> { - let ctx: ConfigStateCtx = serde_json::from_value(value.clone()).map_err(|e| anyhow!(e))?; - self.config = ctx.config; - self.c = ctx.c; - Ok(()) - } -} - #[derive(Serialize, Deserialize)] pub struct OnChainValuesCtx { pub config: Config, @@ -210,32 +216,6 @@ pub struct OnChainValuesCtx { pub values: Vec, } -impl OnChainValuesCtx { - // TODO: may be a from? - pub fn new(config_ctx: ConfigStateCtx) -> Self { - Self { - config: config_ctx.config, - c: config_ctx.c, - values: vec![], - } - } -} - -impl Context for OnChainValuesCtx { - fn read(&self) -> Result { - serde_json::to_value(self).map_err(|e| anyhow!(e)) - } - - fn write(&mut self, value: &Value) -> Result<(), Error> { - let ctx: OnChainValuesCtx = - serde_json::from_value(value.clone()).map_err(|e| anyhow!(e))?; - self.config = ctx.config; - self.c = ctx.c; - self.values = ctx.values; - Ok(()) - } -} - #[derive(Debug, Clone, PartialEq, StateMetadataReqs)] pub struct ConfigState { label: Label, @@ -244,6 +224,8 @@ pub struct ConfigState { depends_on_strategy: DependencyStrategy, } +pub const CONFIG: &str = "config"; + impl Default for ConfigState { fn default() -> Self { Self::new() @@ -262,18 +244,21 @@ impl ConfigState { impl StateHandler for ConfigState { fn handler(&self, context: ContextWrapper) -> StateResult { - //let value = context.lock().unwrap().read().unwrap(); - //let _data: ConfigStateCtx = serde_json::from_value(value).unwrap(); - let config = Config { a: "config_a".to_string(), b: "config_b".to_string(), }; - let config_state_ctx = ConfigStateCtx::new(config); + let c = "".to_string(); + let config_state_ctx = ConfigStateCtx { config, c }; let data = serde_json::to_value(config_state_ctx).unwrap(); - match context.lock().as_mut().unwrap().write(&data) { + match context + .lock() + .as_mut() + .unwrap() + .write(CONFIG.to_string(), &data) + { Ok(()) => Ok(()), Err(e) => Err(StateError::StorageAccess( StateErrorRecoverability::Recoverable, @@ -290,6 +275,7 @@ pub struct OnChainValuesState { depends_on: Vec, depends_on_strategy: DependencyStrategy, } +pub const ONCHAINVALUES: &str = "onchain_values"; impl Default for OnChainValuesState { fn default() -> Self { @@ -310,15 +296,23 @@ impl OnChainValuesState { impl StateHandler for OnChainValuesState { fn handler(&self, context: ContextWrapper) -> StateResult { - let value = context.lock().unwrap().read().unwrap(); + let value = context.lock().unwrap().read(CONFIG.to_string()).unwrap(); let _data: ConfigStateCtx = serde_json::from_value(value).unwrap(); - let mut onchain_value_ctx = OnChainValuesCtx::new(_data); - onchain_value_ctx.values = vec!["txn1".to_string(), "txn2".to_string()]; + let onchain_value_ctx = OnChainValuesCtx { + config: _data.config, + c: _data.c, + values: vec!["txn1".to_string(), "txn2".to_string()], + }; let data = serde_json::to_value(onchain_value_ctx).unwrap(); - match context.lock().as_mut().unwrap().write(&data) { + match context + .lock() + .as_mut() + .unwrap() + .write(ONCHAINVALUES.to_string(), &data) + { Ok(()) => Ok(()), Err(e) => Err(StateError::StorageAccess( StateErrorRecoverability::Unrecoverable, diff --git a/mfm_machine/tests/n_states_with_n_ctxs.rs b/mfm_machine/tests/n_states_with_n_ctxs.rs index cb17f44..685fe44 100644 --- a/mfm_machine/tests/n_states_with_n_ctxs.rs +++ b/mfm_machine/tests/n_states_with_n_ctxs.rs @@ -1,12 +1,16 @@ use std::sync::Arc; -use default_impls::{ConfigState, ContextA, OnChainValuesState}; +use default_impls::{ConfigState, OnChainValuesState}; use mfm_machine::{ - state::{context::wrap_context, States}, + state::{ + context::{wrap_context, Local}, + States, + }, state_machine::StateMachine, }; +use serde_json::json; -use crate::default_impls::{Config, ConfigStateCtx}; +use crate::default_impls::{Config, CONFIG}; mod default_impls; @@ -22,7 +26,13 @@ fn test_n_states_with_ctxs() { // starting with a useless context // TODO: add an empty context impl - let context = wrap_context(ConfigStateCtx::new(config)); + let context = wrap_context(Local::default()); + + context + .lock() + .unwrap() + .write(CONFIG.to_string(), &json!(config)) + .unwrap(); let initial_states: States = Arc::new([config_state.clone(), onchain_value_state.clone()]); diff --git a/mfm_machine/tests/public_api_test.rs b/mfm_machine/tests/public_api_test.rs index 963e884..efa66a5 100644 --- a/mfm_machine/tests/public_api_test.rs +++ b/mfm_machine/tests/public_api_test.rs @@ -1,7 +1,8 @@ mod default_impls; -use default_impls::{ContextA, Report, Setup}; +use default_impls::{Report, Setup}; use mfm_machine::state::context::wrap_context; +use mfm_machine::state::context::Local; use mfm_machine::state::DependencyStrategy; use mfm_machine::state::Label; use mfm_machine::state::States; @@ -22,22 +23,21 @@ fn test_state_machine_execute() { .iter() .map(|is| { ( - is.label().clone(), + is.label(), is.tags(), is.depends_on(), - is.depends_on_strategy().clone(), + is.depends_on_strategy(), ) }) .collect(); let mut state_machine = StateMachine::new(initial_states); - let context = wrap_context(ContextA::new(String::from("hello"), 7)); + let context = wrap_context(Local::default()); let result = state_machine.execute(context.clone()); - let last_ctx_message = context.lock().unwrap().read().unwrap(); + //let last_ctx_message = context.lock().unwrap().dump().unwrap(); assert_eq!(state_machine.states.len(), iss.len()); - state_machine.states.iter().zip(iss.iter()).for_each( |(s, (label, tags, depends_on, depends_on_strategy))| { assert_eq!(s.label(), *label); @@ -47,11 +47,15 @@ fn test_state_machine_execute() { }, ); - let last_ctx_data: ContextA = serde_json::from_value(last_ctx_message).unwrap(); + // let last_ctx_data: Local = serde_json::from_value(last_ctx_message).unwrap(); + // let report_ctx: ReportCtx = + // serde_json::from_value(last_ctx_data.read("report".to_string()).unwrap()).unwrap(); + // + // println!("report_msg: {}", report_ctx.report_msg); assert!(result.is_ok()); - assert_eq!( - last_ctx_data.a, - String::from("some new data reported: setting up") - ); + // assert_eq!( + // report_ctx.report_msg, + // String::from("some new data reported: setting up") + // ); } diff --git a/mfm_machine/tests/retry_workflow_state_machine.rs b/mfm_machine/tests/retry_workflow_state_machine.rs index 32d76ce..2694676 100644 --- a/mfm_machine/tests/retry_workflow_state_machine.rs +++ b/mfm_machine/tests/retry_workflow_state_machine.rs @@ -1,9 +1,11 @@ mod default_impls; -use default_impls::{ComputePrice, ContextA, Report, Setup}; -use mfm_machine::state::context::wrap_context; +use default_impls::{ComputePrice, Report, Setup}; +use mfm_machine::state::context::{wrap_context, Local}; use mfm_machine::state::States; use mfm_machine::state_machine::StateMachine; +use serde_json::json; +use std::collections::HashMap; use std::sync::Arc; #[test] @@ -12,7 +14,10 @@ fn test_retry_workflow_state_machine() { let compute_price_state = Box::new(ComputePrice::new()); let report_state = Box::new(Report::new()); - let context = wrap_context(ContextA::new(String::from("hello"), 7)); + let context = wrap_context(Local::new(HashMap::from([( + "zero_ctx".to_string(), + json!(0), + )]))); let initial_states: States = Arc::new([ setup_state.clone(), @@ -29,5 +34,7 @@ fn test_retry_workflow_state_machine() { state_machine.track_history() ); + println!("result: {:?}", result); + assert!(result.is_ok()); }