Skip to content

Commit

Permalink
Fix/simulator crashes if no history (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Sep 7, 2024
1 parent 517ecd2 commit c65fe5c
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "1.2.0"
version = "1.2.1"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
3 changes: 2 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use snafu::Snafu;

#[derive(Snafu, Debug)]
#[derive(Snafu, Debug, PartialEq)]
pub enum FSRSError {
NotEnoughData,
Interrupted,
InvalidParameters,
OptimalNotFound,
InvalidInput,
InvalidDeckSize,
}

pub type Result<T, E = FSRSError> = std::result::Result<T, E>;
62 changes: 59 additions & 3 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ pub fn simulate(
learn_limit,
review_limit,
} = config.clone();
if deck_size == 0 {
return Err(FSRSError::InvalidDeckSize);
}
let mut card_table = Array2::zeros((Column::COUNT, deck_size));
card_table
.slice_mut(s![Column::Due, ..])
Expand All @@ -179,6 +182,9 @@ pub fn simulate(

// fill card table based on existing_cards
if let Some(existing_cards) = existing_cards {
if existing_cards.len() > deck_size {
return Err(FSRSError::InvalidDeckSize);
}
for (i, card) in existing_cards.into_iter().enumerate() {
card_table[[Column::Difficulty as usize, i]] = card.difficulty;
card_table[[Column::Stability as usize, i]] = card.stability;
Expand Down Expand Up @@ -665,6 +671,9 @@ pub fn extract_simulator_config(
day_cutoff: i64,
smooth: bool,
) -> SimulatorConfig {
if df.is_empty() {
return SimulatorConfig::default();
}
/*
def rating_counts(x):
tmp = defaultdict(int, x.value_counts().to_dict())
Expand Down Expand Up @@ -790,18 +799,23 @@ pub fn extract_simulator_config(
.collect::<HashMap<_, _>>()
};
// [button_usage_dict.get((1, i), 0) for i in range(1, 5)]
let learn_buttons: [i64; 4] = (1..5)
let mut learn_buttons: [i64; 4] = (1..=4)
.map(|i| button_usage_dict.get(&(1, i)).copied().unwrap_or_default())
.collect_vec()
.try_into()
.unwrap();
if learn_buttons.iter().all(|&x| x == 0) {
learn_buttons = [1, 1, 1, 1];
}
// [button_usage_dict.get((2, i), 0) for i in range(1, 5)]
let review_buttons: [i64; 4] = (1..5)
let mut review_buttons: [i64; 4] = (1..=4)
.map(|i| button_usage_dict.get(&(2, i)).copied().unwrap_or_default())
.collect_vec()
.try_into()
.unwrap();

if review_buttons.iter().skip(1).all(|&x| x == 0) {
review_buttons = [review_buttons[0], 1, 1, 1];
}
// self.first_rating_prob = self.learn_buttons / self.learn_buttons.sum()
let mut first_rating_prob: [f32; 4] = learn_buttons
.iter()
Expand Down Expand Up @@ -1055,6 +1069,42 @@ mod tests {
Ok(())
}

#[test]
fn simulate_with_zero_card() -> Result<()> {
let config = SimulatorConfig {
deck_size: 0,
..Default::default()
};
let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None);
assert_eq!(results.unwrap_err(), FSRSError::InvalidDeckSize);
Ok(())
}

#[test]
fn simulate_with_existing_cards_with_wrong_deck_size() -> Result<()> {
let config = SimulatorConfig {
deck_size: 1,
..Default::default()
};
let cards = vec![
Card {
difficulty: 5.0,
stability: 5.0,
last_date: -5.0,
due: 0.0,
},
Card {
difficulty: 5.0,
stability: 2.0,
last_date: -2.0,
due: 0.0,
},
];
let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, Some(cards));
assert_eq!(results.unwrap_err(), FSRSError::InvalidDeckSize);
Ok(())
}

#[test]
fn optimal_retention() -> Result<()> {
let learn_span = 1000;
Expand Down Expand Up @@ -1129,4 +1179,10 @@ mod tests {
}
);
}

#[test]
fn extract_simulator_config_without_revlog() {
let simulator_config = extract_simulator_config(vec![], 0, true);
assert_eq!(simulator_config, SimulatorConfig::default());
}
}

0 comments on commit c65fe5c

Please sign in to comment.