diff --git a/src/model.rs b/src/model.rs index 75e1e81..06c959c 100644 --- a/src/model.rs +++ b/src/model.rs @@ -66,6 +66,10 @@ impl> Model { * ((-r + 1) * self.w().slice([14..15])).exp() } + fn mean_reversion(&self, init_d: Tensor, new_d: Tensor) -> Tensor { + self.w().slice([7..8]) * (init_d - new_d.clone()) + new_d + } + fn step( &self, i: usize, @@ -82,6 +86,7 @@ impl> Model { let r = self.power_forgetting_curve(delta_t, stability.clone()); // dbg!(&r); let new_d = difficulty - self.w().slice([6..7]) * (rating.clone() - 3); + let new_d = self.mean_reversion(self.w().slice([4..5]), new_d); let new_d = new_d.clamp(1.0, 10.0); // dbg!(&new_d); let s_recall = self.stability_after_success(