From 2451c85453cbd90ea2ed4dbe89187f469e23efb6 Mon Sep 17 00:00:00 2001 From: Expertium <83031600+Expertium@users.noreply.github.com> Date: Tue, 21 May 2024 15:05:33 +0300 Subject: [PATCH 1/7] Replace "weight" with "parameter".rs --- src/weight_clipper.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weight_clipper.rs b/src/weight_clipper.rs index 13e2df0a..e7cd849c 100644 --- a/src/weight_clipper.rs +++ b/src/weight_clipper.rs @@ -4,7 +4,7 @@ use crate::{ }; use burn::tensor::{backend::Backend, Data, Tensor}; -pub(crate) fn weight_clipper(parameters: Tensor) -> Tensor { +pub(crate) fn parameter_clipper(parameters: Tensor) -> Tensor { let val = clip_parameters(¶meters.to_data().convert().value); Tensor::from_data( Data::new(val, parameters.shape()).convert(), @@ -49,14 +49,14 @@ mod tests { use burn::backend::ndarray::NdArrayDevice; #[test] - fn weight_clipper_works() { + fn parameter_clipper_works() { let device = NdArrayDevice::Cpu; let tensor = Tensor::from_floats( [0.0, -1000.0, 1000.0, 0.0, 1000.0, -1000.0, 1.0, 0.25, -0.1], &device, ); - let param: Tensor<1> = weight_clipper(tensor); + let param: Tensor<1> = parameter_clipper(tensor); let values = ¶m.to_data().value; assert_eq!( From 67ed649b1f09f2d3cd347d4eedf51a4641fc334d Mon Sep 17 00:00:00 2001 From: Expertium <83031600+Expertium@users.noreply.github.com> Date: Tue, 21 May 2024 15:07:01 +0300 Subject: [PATCH 2/7] Update lib.rs --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 4229b81b..eab0129b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ mod pre_training; #[cfg(test)] mod test_helpers; mod training; -mod weight_clipper; +mod parameter_clipper; pub use dataset::{FSRSItem, FSRSReview}; pub use error::{FSRSError, Result}; From 9a92081d22813c30e043f13af8b7ee9558b59138 Mon Sep 17 00:00:00 2001 From: Expertium <83031600+Expertium@users.noreply.github.com> Date: Tue, 21 May 2024 15:07:38 +0300 Subject: [PATCH 3/7] Update training.rs --- src/training.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/training.rs b/src/training.rs index 4a819376..804f858f 100644 --- a/src/training.rs +++ b/src/training.rs @@ -4,7 +4,7 @@ use crate::dataset::{split_filter_data, FSRSBatcher, FSRSDataset, FSRSItem}; use crate::error::Result; use crate::model::{Model, ModelConfig}; use crate::pre_training::pretrain; -use crate::weight_clipper::weight_clipper; +use crate::parameter_clipper::parameter_clipper; use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS}; use burn::backend::Autodiff; @@ -267,7 +267,7 @@ impl FSRS { if optimized_parameters .iter() - .any(|weight: &f32| weight.is_infinite()) + .any(|parameter: &f32| parameter.is_infinite()) { return Err(FSRSError::InvalidInput); } @@ -358,7 +358,7 @@ fn train( } let grads = GradientsParams::from_grads(gradients, &model); model = optim.step(lr, model, grads); - model.w = Param::from_tensor(weight_clipper(model.w.val())); + model.w = Param::from_tensor(parameter_clipper(model.w.val())); // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr); renderer.render_train(TrainingProgress { progress, From 11c335ab66890c0b753cb5791ab02810e200f62e Mon Sep 17 00:00:00 2001 From: Expertium <83031600+Expertium@users.noreply.github.com> Date: Tue, 21 May 2024 15:08:04 +0300 Subject: [PATCH 4/7] Update model.rs --- src/model.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model.rs b/src/model.rs index 1052b012..59761946 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,6 +1,6 @@ use crate::error::{FSRSError, Result}; use crate::inference::{Parameters, DECAY, FACTOR, S_MIN}; -use crate::weight_clipper::clip_parameters; +use crate::parameter_clipper::clip_parameters; use crate::DEFAULT_PARAMETERS; use burn::backend::ndarray::NdArrayDevice; use burn::backend::NdArray; @@ -196,7 +196,7 @@ impl ModelConfig { } /// This is the main structure provided by this crate. It can be used -/// for both weight training, and for reviews. +/// for both parameter training, and for reviews. #[derive(Debug, Clone)] pub struct FSRS { model: Option>, From 69a63c5ddcca9a7f5fdbe909c462f83859cce272 Mon Sep 17 00:00:00 2001 From: Expertium <83031600+Expertium@users.noreply.github.com> Date: Tue, 21 May 2024 15:08:24 +0300 Subject: [PATCH 5/7] Update optimal_retention.rs --- src/optimal_retention.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 70d67bf6..c6a1762c 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -7,7 +7,7 @@ use ndarray::{s, Array1, Array2, Ix0, Ix1, SliceInfoElem, Zip}; use ndarray_rand::rand_distr::Distribution; use ndarray_rand::RandomExt; use rand::{ - distributions::{Uniform, WeightedIndex}, + distributions::{Uniform, parameteredIndex}, rngs::StdRng, SeedableRng, }; @@ -148,10 +148,10 @@ pub fn simulate( let mut cost_per_day = Array1::zeros(learn_span); let first_rating_choices = [1, 2, 3, 4]; - let first_rating_dist = WeightedIndex::new(first_rating_prob).unwrap(); + let first_rating_dist = parameteredIndex::new(first_rating_prob).unwrap(); let review_rating_choices = [2, 3, 4]; - let review_rating_dist = WeightedIndex::new(review_rating_prob).unwrap(); + let review_rating_dist = parameteredIndex::new(review_rating_prob).unwrap(); let mut rng = StdRng::seed_from_u64(seed.unwrap_or(42)); From 92681df1e2c7b2084733f4aff5caa67188f8d273 Mon Sep 17 00:00:00 2001 From: Expertium <83031600+Expertium@users.noreply.github.com> Date: Tue, 21 May 2024 16:39:32 +0300 Subject: [PATCH 6/7] Rename weight_clipper.rs to parameter_clipper.rs --- src/{weight_clipper.rs => parameter_clipper.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/{weight_clipper.rs => parameter_clipper.rs} (100%) diff --git a/src/weight_clipper.rs b/src/parameter_clipper.rs similarity index 100% rename from src/weight_clipper.rs rename to src/parameter_clipper.rs From 720b461c13661f1992ae570d72f4cf9825a1deda Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Wed, 22 May 2024 11:10:22 +0800 Subject: [PATCH 7/7] fix format & revert parameteredIndex --- src/lib.rs | 2 +- src/optimal_retention.rs | 6 +++--- src/training.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index eab0129b..13a79134 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,11 +9,11 @@ mod error; mod inference; mod model; mod optimal_retention; +mod parameter_clipper; mod pre_training; #[cfg(test)] mod test_helpers; mod training; -mod parameter_clipper; pub use dataset::{FSRSItem, FSRSReview}; pub use error::{FSRSError, Result}; diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index c6a1762c..70d67bf6 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -7,7 +7,7 @@ use ndarray::{s, Array1, Array2, Ix0, Ix1, SliceInfoElem, Zip}; use ndarray_rand::rand_distr::Distribution; use ndarray_rand::RandomExt; use rand::{ - distributions::{Uniform, parameteredIndex}, + distributions::{Uniform, WeightedIndex}, rngs::StdRng, SeedableRng, }; @@ -148,10 +148,10 @@ pub fn simulate( let mut cost_per_day = Array1::zeros(learn_span); let first_rating_choices = [1, 2, 3, 4]; - let first_rating_dist = parameteredIndex::new(first_rating_prob).unwrap(); + let first_rating_dist = WeightedIndex::new(first_rating_prob).unwrap(); let review_rating_choices = [2, 3, 4]; - let review_rating_dist = parameteredIndex::new(review_rating_prob).unwrap(); + let review_rating_dist = WeightedIndex::new(review_rating_prob).unwrap(); let mut rng = StdRng::seed_from_u64(seed.unwrap_or(42)); diff --git a/src/training.rs b/src/training.rs index 804f858f..532dabdb 100644 --- a/src/training.rs +++ b/src/training.rs @@ -3,8 +3,8 @@ use crate::cosine_annealing::CosineAnnealingLR; use crate::dataset::{split_filter_data, FSRSBatcher, FSRSDataset, FSRSItem}; use crate::error::Result; use crate::model::{Model, ModelConfig}; -use crate::pre_training::pretrain; use crate::parameter_clipper::parameter_clipper; +use crate::pre_training::pretrain; use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS}; use burn::backend::Autodiff;