Skip to content

Commit

Permalink
test: making the state_machine recoverable and fault-tolerant
Browse files Browse the repository at this point in the history
  • Loading branch information
willyrgf committed Nov 25, 2023
1 parent 834c34e commit de9b7a6
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 61 deletions.
21 changes: 0 additions & 21 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion mfm_machine/src/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod context;

use context::Context;

use self::context::ContextWrapper;

#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct Tag(&'static str);

Expand Down Expand Up @@ -61,7 +63,7 @@ pub trait StateMetadata {
pub type StateResult = Result<(), StateError>;

pub trait StateHandler: StateMetadata + Send + Sync {
fn handler(&self, context: Arc<Mutex<Box<dyn Context>>>) -> StateResult;
fn handler(&self, context: ContextWrapper) -> StateResult;
}

pub type States = Arc<[Box<dyn StateHandler>]>;
Expand Down
11 changes: 0 additions & 11 deletions mfm_machine/src/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,11 @@ impl StateMachine {

let last_index_of_first_dep = indexes_state_deps.last().unwrap().clone();

println!(
"trying to recover it from {:?}::{} to {:?}::{}",
state.label(),
state_index,
last_index_of_first_dep.state_label,
last_index_of_first_dep.state_index
);

let last_index_state_ctx =
tracker.recover(last_index_of_first_dep.clone()).unwrap();

println!("are we waiting for some lock here??");
context = last_index_state_ctx.clone();

println!("are we waiting for some lock here???");

// TODO: design the possible state recoverability and default cases
Ok((last_index_of_first_dep.state_index, context))
} else {
Expand Down
40 changes: 23 additions & 17 deletions mfm_machine/tests/default_impls.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use std::sync::Arc;
use std::sync::Mutex;

use anyhow::anyhow;
use anyhow::Error;
use mfm_machine::state::context::Context;
use mfm_machine::state::context::ContextWrapper;
use mfm_machine::state::DependencyStrategy;
use mfm_machine::state::Label;
use mfm_machine::state::StateConfig;
use mfm_machine::state::StateError;
use mfm_machine::state::StateErrorRecoverability;
use mfm_machine::state::StateHandler;
use mfm_machine::state::StateMetadata;
use mfm_machine::state::StateResult;
use mfm_machine::state::Tag;
use mfm_machine::StateConfigReqs;
use mfm_machine_macros::StateMetadataReqs;
use rand::Rng;
use serde_derive::{Deserialize, Serialize};
use serde_json::{json, Value};
Expand Down Expand Up @@ -39,7 +43,7 @@ impl Context for ContextA {
}
}

#[derive(Debug, Clone, PartialEq, StateConfigReqs)]
#[derive(Debug, Clone, PartialEq, StateMetadataReqs)]
pub struct Setup {
label: Label,
tags: Vec<Tag>,
Expand All @@ -52,7 +56,6 @@ impl Default for Setup {
Self::new()
}
}

impl Setup {
pub fn new() -> Self {
Self {
Expand All @@ -65,11 +68,14 @@ impl Setup {
}

impl StateHandler for Setup {
fn handler(&self, context: &mut dyn Context) -> StateResult {
let _data: ContextA = serde_json::from_value(context.read_input().unwrap()).unwrap();
fn handler(&self, context: ContextWrapper) -> StateResult {
let value = context.lock().unwrap().read_input().unwrap();
let _data: ContextA = serde_json::from_value(value).unwrap();

let mut rng = rand::thread_rng();
let data = json!({ "a": "setting up", "b": rng.gen_range(0..9) });
match context.write_output(&data) {

match context.lock().as_mut().unwrap().write_output(&data) {
Ok(()) => Ok(()),
Err(e) => Err(StateError::StorageAccess(
StateErrorRecoverability::Recoverable,
Expand All @@ -79,7 +85,7 @@ impl StateHandler for Setup {
}
}

#[derive(Debug, Clone, PartialEq, StateConfigReqs)]
#[derive(Debug, Clone, PartialEq, StateMetadataReqs)]
pub struct ComputePrice {
label: Label,
tags: Vec<Tag>,
Expand All @@ -92,7 +98,6 @@ impl Default for ComputePrice {
Self::new()
}
}

impl ComputePrice {
pub fn new() -> Self {
Self {
Expand All @@ -105,8 +110,9 @@ impl ComputePrice {
}

impl StateHandler for ComputePrice {
fn handler(&self, context: &mut dyn Context) -> StateResult {
let _data: ContextA = serde_json::from_value(context.read_input().unwrap()).unwrap();
fn handler(&self, context: ContextWrapper) -> StateResult {
let value = context.lock().unwrap().read_input().unwrap();
let _data: ContextA = serde_json::from_value(value).unwrap();
if _data.b % 2 == 0 {
return Err(StateError::ParsingInput(
StateErrorRecoverability::Recoverable,
Expand All @@ -115,7 +121,7 @@ impl StateHandler for ComputePrice {
}

let data = json!({ "a": "the input number is odd", "b": _data.b });
match context.write_output(&data) {
match context.lock().as_mut().unwrap().write_output(&data) {
Ok(()) => Ok(()),
Err(e) => Err(StateError::StorageAccess(
StateErrorRecoverability::Unrecoverable,
Expand All @@ -125,7 +131,7 @@ impl StateHandler for ComputePrice {
}
}

#[derive(Debug, Clone, PartialEq, StateConfigReqs)]
#[derive(Debug, Clone, PartialEq, StateMetadataReqs)]
pub struct Report {
label: Label,
tags: Vec<Tag>,
Expand All @@ -138,7 +144,6 @@ impl Default for Report {
Self::new()
}
}

impl Report {
pub fn new() -> Self {
Self {
Expand All @@ -151,11 +156,12 @@ impl Report {
}

impl StateHandler for Report {
fn handler(&self, context: &mut dyn Context) -> StateResult {
let _data: ContextA = serde_json::from_value(context.read_input().unwrap()).unwrap();
fn handler(&self, context: ContextWrapper) -> StateResult {
let value = context.lock().unwrap().read_input().unwrap();
let _data: ContextA = serde_json::from_value(value).unwrap();
let data =
json!({ "a": format!("{}: {}", "some new data reported", _data.a), "b": _data.b });
match context.write_output(&data) {
match context.lock().as_mut().unwrap().write_output(&data) {
Ok(()) => Ok(()),
Err(e) => Err(StateError::StorageAccess(
StateErrorRecoverability::Recoverable,
Expand Down
13 changes: 6 additions & 7 deletions mfm_machine/tests/public_api_test.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
mod default_impls;

use default_impls::{ContextA, Report, Setup};
use mfm_machine::state::context::Context;
use mfm_machine::state::context::wrap_context;
use mfm_machine::state::DependencyStrategy;
use mfm_machine::state::Label;
use mfm_machine::state::StateHandler;
use mfm_machine::state::States;
use mfm_machine::state::Tag;
use mfm_machine::state_machine::StateMachine;

Expand All @@ -15,8 +15,7 @@ fn test_state_machine_execute() {
let setup_state = Box::new(Setup::new());
let report_state = Box::new(Report::new());

let initial_states: Arc<[Box<dyn StateHandler>]> =
Arc::new([setup_state.clone(), report_state.clone()]);
let initial_states: States = Arc::new([setup_state.clone(), report_state.clone()]);
let initial_states_cloned = initial_states.clone();

let iss: Vec<(Label, Vec<Tag>, Vec<Tag>, DependencyStrategy)> = initial_states_cloned
Expand All @@ -33,9 +32,9 @@ fn test_state_machine_execute() {

let mut state_machine = StateMachine::new(initial_states);

let context: &mut dyn Context = &mut ContextA::new(String::from("hello"), 7);
let result = state_machine.execute(context);
let last_ctx_message = context.read_input().unwrap();
let context = wrap_context(ContextA::new(String::from("hello"), 7));
let result = state_machine.execute(context.clone());
let last_ctx_message = context.lock().unwrap().read_input().unwrap();

assert_eq!(state_machine.states.len(), iss.len());

Expand Down
8 changes: 4 additions & 4 deletions mfm_machine/tests/retry_workflow_state_machine_test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mod default_impls;

use default_impls::{ComputePrice, ContextA, Report, Setup};
use mfm_machine::state::context::Context;
use mfm_machine::state::StateHandler;
use mfm_machine::state::context::wrap_context;
use mfm_machine::state::States;
use mfm_machine::state_machine::StateMachine;
use std::sync::Arc;

Expand All @@ -12,9 +12,9 @@ fn test_retry_workflow_state_machine() {
let compute_price_state = Box::new(ComputePrice::new());
let report_state = Box::new(Report::new());

let context: &mut dyn Context = &mut ContextA::new(String::from("hello"), 7);
let context = wrap_context(ContextA::new(String::from("hello"), 7));

let initial_states: Arc<[Box<dyn StateHandler>]> = Arc::new([
let initial_states: States = Arc::new([
setup_state.clone(),
compute_price_state.clone(),
report_state.clone(),
Expand Down

0 comments on commit de9b7a6

Please sign in to comment.