diff --git a/src/model.rs b/src/model.rs index ab5b2f9..83c38bd 100644 --- a/src/model.rs +++ b/src/model.rs @@ -250,6 +250,8 @@ pub(crate) fn check_and_fill_parameters(parameters: &Parameters) -> Result DEFAULT_PARAMETERS.to_vec(), 17 => { let mut parameters = parameters.to_vec(); + parameters[4] = parameters[5].mul_add(2.0, parameters[4]); + parameters[5] = parameters[5].mul_add(3.0, 1.0).ln() / 3.0; parameters.extend_from_slice(&[0.0, 0.0]); parameters } @@ -274,6 +276,22 @@ mod tests { assert_eq!(model.w.val().to_data(), Data::from(DEFAULT_PARAMETERS)) } + #[test] + fn convert_parameters() { + let fsrs4dot5_param = vec![ + 0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, + 0.29, 2.61, + ]; + let fsrs5_param = check_and_fill_parameters(&fsrs4dot5_param).unwrap(); + assert_eq!( + fsrs5_param, + vec![ + 0.4, 0.6, 2.4, 5.8, 6.81, 0.44675013, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, + 0.34, 1.26, 0.29, 2.61, 0.0, 0.0, + ] + ) + } + #[test] fn power_forgetting_curve() { let device = NdArrayDevice::Cpu; diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index e04737a..f824a21 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -1144,9 +1144,9 @@ mod tests { learn_limit, ..Default::default() }; - let optimal_retention = fsrs - .optimal_retention(&config, &DEFAULT_PARAMETERS[..17], |_v| true) - .unwrap(); + let mut param = DEFAULT_PARAMETERS[..17].to_vec(); + param.extend_from_slice(&[0.0, 0.0]); + let optimal_retention = fsrs.optimal_retention(&config, ¶m, |_v| true).unwrap(); assert_eq!(optimal_retention, 0.85450846); Ok(()) }