diff --git a/src/lib.rs b/src/lib.rs
index 4229b81..13a7913 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,11 +9,11 @@ mod error;
 mod inference;
 mod model;
 mod optimal_retention;
+mod parameter_clipper;
 mod pre_training;
 #[cfg(test)]
 mod test_helpers;
 mod training;
-mod weight_clipper;
 
 pub use dataset::{FSRSItem, FSRSReview};
 pub use error::{FSRSError, Result};
diff --git a/src/model.rs b/src/model.rs
index 1052b01..5976194 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,6 +1,6 @@
 use crate::error::{FSRSError, Result};
 use crate::inference::{Parameters, DECAY, FACTOR, S_MIN};
-use crate::weight_clipper::clip_parameters;
+use crate::parameter_clipper::clip_parameters;
 use crate::DEFAULT_PARAMETERS;
 use burn::backend::ndarray::NdArrayDevice;
 use burn::backend::NdArray;
@@ -196,7 +196,7 @@ impl ModelConfig {
 }
 
 /// This is the main structure provided by this crate. It can be used
-/// for both weight training, and for reviews.
+/// for both parameter training, and for reviews.
 #[derive(Debug, Clone)]
 pub struct FSRS<B: Backend = NdArray> {
     model: Option<Model<B>>,
diff --git a/src/weight_clipper.rs b/src/parameter_clipper.rs
similarity index 89%
rename from src/weight_clipper.rs
rename to src/parameter_clipper.rs
index 13e2df0..e7cd849 100644
--- a/src/weight_clipper.rs
+++ b/src/parameter_clipper.rs
@@ -4,7 +4,7 @@ use crate::{
 };
 use burn::tensor::{backend::Backend, Data, Tensor};
 
-pub(crate) fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
+pub(crate) fn parameter_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
     let val = clip_parameters(&parameters.to_data().convert().value);
     Tensor::from_data(
         Data::new(val, parameters.shape()).convert(),
@@ -49,14 +49,14 @@ mod tests {
     use burn::backend::ndarray::NdArrayDevice;
 
     #[test]
-    fn weight_clipper_works() {
+    fn parameter_clipper_works() {
         let device = NdArrayDevice::Cpu;
         let tensor = Tensor::from_floats(
             [0.0, -1000.0, 1000.0, 0.0, 1000.0, -1000.0, 1.0, 0.25, -0.1],
             &device,
         );
 
-        let param: Tensor<1> = weight_clipper(tensor);
+        let param: Tensor<1> = parameter_clipper(tensor);
         let values = &param.to_data().value;
 
         assert_eq!(
diff --git a/src/training.rs b/src/training.rs
index 4a81937..532dabd 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -3,8 +3,8 @@ use crate::cosine_annealing::CosineAnnealingLR;
 use crate::dataset::{split_filter_data, FSRSBatcher, FSRSDataset, FSRSItem};
 use crate::error::Result;
 use crate::model::{Model, ModelConfig};
+use crate::parameter_clipper::parameter_clipper;
 use crate::pre_training::pretrain;
-use crate::weight_clipper::weight_clipper;
 use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS};
 
 use burn::backend::Autodiff;
@@ -267,7 +267,7 @@ impl<B: Backend> FSRS<B> {
 
         if optimized_parameters
             .iter()
-            .any(|weight: &f32| weight.is_infinite())
+            .any(|parameter: &f32| parameter.is_infinite())
         {
             return Err(FSRSError::InvalidInput);
         }
@@ -358,7 +358,7 @@ fn train<B: AutodiffBackend>(
         }
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = Param::from_tensor(weight_clipper(model.w.val()));
+        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
         // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
         renderer.render_train(TrainingProgress {
             progress,