From 3721cdfd00d74213dfa27dd4119d9bbb8478514d Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 29 Oct 2024 23:10:05 +0800 Subject: [PATCH] Fix/clamp stability in simulator (#250) * Fix/clamp stability in simulator * clip parameters in simulator --- src/inference.rs | 1 + src/model.rs | 4 ++-- src/optimal_retention.rs | 11 +++++++---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index 7153344..9eb60a0 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -17,6 +17,7 @@ pub(crate) const DECAY: f64 = -0.5; /// (9/10) ^ (1 / DECAY) - 1 pub(crate) const FACTOR: f64 = 19f64 / 81f64; pub(crate) const S_MIN: f32 = 0.01; +pub(crate) const S_MAX: f32 = 36500.0; /// This is a slice for efficiency, but should always be 17 in length. pub type Parameters = [f32]; use itertools::izip; diff --git a/src/model.rs b/src/model.rs index 7e6c067..4e5f172 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,5 +1,5 @@ use crate::error::{FSRSError, Result}; -use crate::inference::{Parameters, DECAY, FACTOR, S_MIN}; +use crate::inference::{Parameters, DECAY, FACTOR, S_MAX, S_MIN}; use crate::parameter_clipper::clip_parameters; use crate::DEFAULT_PARAMETERS; use burn::backend::ndarray::NdArrayDevice; @@ -163,7 +163,7 @@ impl Model { ) }; MemoryStateTensors { - stability: new_s.clamp(S_MIN, 36500.0), + stability: new_s.clamp(S_MIN, S_MAX), difficulty: new_d, } } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index fc381ce..09e3941 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -1,6 +1,7 @@ use crate::error::{FSRSError, Result}; -use crate::inference::{next_interval, ItemProgress, Parameters, DECAY, FACTOR, S_MIN}; +use crate::inference::{next_interval, ItemProgress, Parameters, DECAY, FACTOR, S_MAX, S_MIN}; use crate::model::check_and_fill_parameters; +use crate::parameter_clipper::clip_parameters; use crate::FSRS; use burn::tensor::backend::Backend; use itertools::{izip, Itertools}; @@ -71,12 +72,13 @@ impl Default for SimulatorConfig { fn stability_after_success(w: &[f32], s: f32, r: f32, d: f32, rating: usize) -> f32 { let hard_penalty = if rating == 2 { w[15] } else { 1.0 }; let easy_bonus = if rating == 4 { w[16] } else { 1.0 }; - s * (f32::exp(w[8]) + (s * (f32::exp(w[8]) * (11.0 - d) * s.powf(-w[9]) * (f32::exp((1.0 - r) * w[10]) - 1.0) * hard_penalty) - .mul_add(easy_bonus, 1.0) + .mul_add(easy_bonus, 1.0)) + .clamp(S_MIN, S_MAX) } fn stability_after_failure(w: &[f32], s: f32, r: f32, d: f32) -> f32 { @@ -85,7 +87,7 @@ fn stability_after_failure(w: &[f32], s: f32, r: f32, d: f32) -> f32 { } fn stability_short_term(w: &[f32], s: f32, rating_offset: f32, session_len: f32) -> f32 { - s * (w[17] * (rating_offset + session_len * w[18])).exp() + (s * (w[17] * (rating_offset + session_len * w[18])).exp()).clamp(S_MIN, S_MAX) } fn init_d(w: &[f32], rating: usize) -> f32 { @@ -132,6 +134,7 @@ pub fn simulate( existing_cards: Option>, ) -> Result<(Array1, Array1, Array1, Array1), FSRSError> { let w = &check_and_fill_parameters(w)?; + let w = &clip_parameters(w); let SimulatorConfig { deck_size, learn_span,