Skip to content

Commit

Permalink
generalize context for states and add state machine execution
Browse files Browse the repository at this point in the history
  • Loading branch information
willyrgf committed Oct 14, 2023
1 parent 5b71597 commit 4287567
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 59 deletions.
9 changes: 5 additions & 4 deletions mfm_machine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mfm_machine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ edition = "2021"
[dependencies]
anyhow = "1.0.75"
serde = "1.0.188"
serde_derive = "1.0.189"
serde_json = "1.0.107"
44 changes: 17 additions & 27 deletions mfm_machine/src/state/context.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,31 @@
use anyhow::{anyhow, Error};
use serde::{Deserialize, Serialize};

pub trait Context {
type Output: Deserialize<'static>;
type Input: Serialize;

fn read(&self) -> Self::Output;
fn write(&self, ctx_input: &Self::Input);
fn read<T: for<'de> Deserialize<'de>>(&self) -> Result<T, Error>;
fn write<T: Serialize>(&mut self, data: &T) -> Result<(), Error>;
}

#[derive(Debug)]
pub struct ContextInput {}

#[derive(Debug)]
pub struct ContextOutput {}

impl Context for ContextInput {
type Output = String;
type Input = String;

fn read(&self) -> Self::Output {
"hello".to_string()
}
pub struct RawContext {
data: String,
}

fn write(&self, ctx_input: &Self::Input) {
let _x = ctx_input;
impl RawContext {
pub fn new() -> Self {
Self {
data: "{}".to_string(),
}
}
}

impl Context for ContextOutput {
type Input = String;
type Output = String;

fn read(&self) -> Self::Output {
"hello".to_string()
impl Context for RawContext {
fn read<T: for<'de> Deserialize<'de>>(&self) -> Result<T, Error> {
serde_json::from_str(&self.data).map_err(|e| anyhow!("error on deserialize: {}", e))
}

fn write(&self, ctx_input: &Self::Input) {
let _x = ctx_input;
fn write<T: Serialize>(&mut self, data: &T) -> Result<(), Error> {
self.data = serde_json::to_string(data)?;
Ok(())
}
}
5 changes: 1 addition & 4 deletions mfm_machine/src/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ trait StateConfig {
}

trait StateHandler: StateConfig {
type InputContext: Context;
type OutputContext: Context;

fn handler(&self, context: Self::InputContext) -> Result<Self::OutputContext, Error>;
fn handler<C: Context>(&self, context: &mut C) -> Result<(), Error>;
}

// Those states are mfm-specific states, and should be moved to the app side
Expand Down
44 changes: 39 additions & 5 deletions mfm_machine/src/state/state_machine.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::usize;

use super::{State, StateHandler};
use anyhow::{anyhow, Error};

use super::{context::Context, State, StateError, StateHandler};

struct StateMachine<T> {
states: Vec<State<T>>,
current_state_index: usize,
}

impl<T> StateMachine<T>
Expand All @@ -14,11 +15,44 @@ where
pub fn new(initial_states: Vec<State<T>>) -> Self {
Self {
states: initial_states,
current_state_index: 0,
}
}

fn has_next_state(&self) -> bool {
self.states.len() > self.current_state_index + 1
fn has_state(&self, state_index: usize) -> bool {
self.states.len() > state_index
}

pub fn execute(&self, context: &mut impl Context, state_index: usize) -> Result<(), Error> {
if !self.has_state(0) {
return Err(anyhow!("no states defined to execute"));
}

if !self.has_state(state_index) {
return Ok(());
}

let current_state = &self.states[state_index];

let result = match current_state {
State::Setup(h) => h.handler(context),
State::Report(h) => h.handler(context),
};

// TODO: it may be the transition
match result {
Ok(()) => self.execute(context, state_index + 1),
Err(e) => {
match e.downcast::<StateError>() {
Ok(se) if se.is_recoverable() => {
// TODO: replay based on depends_on logic
// TODO: extract it to an generic func
Err(se.into())
}
Ok(se) => Err(se.into()),
// TODO: what we want to return in this case?
Err(ed) => Err(ed),
}
}
}
}
}
32 changes: 13 additions & 19 deletions mfm_machine/src/state/states.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use anyhow::Error;
use serde_derive::{Deserialize, Serialize};

use crate::state::{Label, StateConfig, StateHandler, Tag};

use super::{
context::{Context, ContextInput, ContextOutput},
context::{Context, RawContext},
DependencyStrategy,
};

Expand All @@ -25,16 +26,14 @@ impl SetupState {
}
}

impl StateHandler for SetupState {
type InputContext = ContextInput;
type OutputContext = ContextOutput;
#[derive(Debug, Deserialize, Serialize)]
struct SetupStateData {}

fn handler(&self, context: ContextInput) -> Result<ContextOutput, Error> {
let _data = context.read();
impl StateHandler for SetupState {
fn handler<C: Context>(&self, context: &mut C) -> Result<(), Error> {
let _data: SetupStateData = context.read().unwrap();
let data = "some new data".to_string();
let ctx_output = ContextOutput {};
ctx_output.write(&data);
Ok(ctx_output)
context.write(&data)
}
}

Expand Down Expand Up @@ -75,15 +74,10 @@ impl ReportState {
}

impl StateHandler for ReportState {
type InputContext = ContextInput;
type OutputContext = ContextOutput;

fn handler(&self, context: ContextInput) -> Result<ContextOutput, Error> {
let _data = context.read();
fn handler<C: Context>(&self, context: &mut C) -> Result<(), Error> {
let _data: String = context.read().unwrap();
let data = "some new data".to_string();
let ctx_output = ContextOutput {};
ctx_output.write(&data);
Ok(ctx_output)
context.write(&data)
}
}

Expand Down Expand Up @@ -116,10 +110,10 @@ mod test {
let label = Label::new("setup_state").unwrap();
let tags = vec![Tag::new("setup").unwrap()];
let state: State<SetupState> = State::Setup(SetupState::new());
let ctx_input = ContextInput {};
let mut ctx_input = RawContext::new();
match state {
State::Setup(t) => {
let result = t.handler(ctx_input);
let result = t.handler(&mut ctx_input);
assert!(result.is_ok());
assert_eq!(t.label(), &label);
assert_eq!(t.tags(), &tags);
Expand Down

0 comments on commit 4287567

Please sign in to comment.