Replace "weight" with "parameter".rs #193

Merged · 7 commits · May 22, 2024

Changes from 6 commits
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -13,7 +13,7 @@ mod pre_training;
#[cfg(test)]
mod test_helpers;
mod training;
-mod weight_clipper;
+mod parameter_clipper;

pub use dataset::{FSRSItem, FSRSReview};
pub use error::{FSRSError, Result};
4 changes: 2 additions & 2 deletions src/model.rs
@@ -1,6 +1,6 @@
use crate::error::{FSRSError, Result};
use crate::inference::{Parameters, DECAY, FACTOR, S_MIN};
-use crate::weight_clipper::clip_parameters;
+use crate::parameter_clipper::clip_parameters;
use crate::DEFAULT_PARAMETERS;
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::NdArray;
@@ -196,7 +196,7 @@ impl ModelConfig {
}

/// This is the main structure provided by this crate. It can be used
-/// for both weight training, and for reviews.
+/// for both parameter training, and for reviews.
#[derive(Debug, Clone)]
pub struct FSRS<B: Backend = NdArray> {
model: Option<Model<B>>,
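The doc comment above covers both uses of the struct. A minimal construction sketch follows; the `fsrs` crate name, the `FSRS::new` signature, and the public export of `DEFAULT_PARAMETERS` are assumptions based on the imports visible in this diff, not something the diff itself shows:

```rust
// Hypothetical usage sketch, not from this PR: assumes FSRS::new takes an
// optional parameter slice and that DEFAULT_PARAMETERS is exported.
use fsrs::{Result, DEFAULT_PARAMETERS, FSRS};

fn main() -> Result<()> {
    // One struct serves both parameter training and review scheduling.
    let _fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
    Ok(())
}
```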
src/optimal_retention.rs
@@ -7,7 +7,7 @@ use ndarray::{s, Array1, Array2, Ix0, Ix1, SliceInfoElem, Zip};
use ndarray_rand::rand_distr::Distribution;
use ndarray_rand::RandomExt;
use rand::{
distributions::{Uniform, WeightedIndex},
rngs::StdRng,
SeedableRng,
};
@@ -148,10 +148,10 @@ pub fn simulate(
let mut cost_per_day = Array1::zeros(learn_span);

let first_rating_choices = [1, 2, 3, 4];
let first_rating_dist = WeightedIndex::new(first_rating_prob).unwrap();

let review_rating_choices = [2, 3, 4];
let review_rating_dist = WeightedIndex::new(review_rating_prob).unwrap();

let mut rng = StdRng::seed_from_u64(seed.unwrap_or(42));

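Worth noting for reviewers: `WeightedIndex` belongs to the `rand` crate, so the blanket "weight" → "parameter" rename must leave it alone. A standalone sketch of how `simulate` draws ratings from it; the probabilities below are illustrative values, not the crate's defaults:

```rust
use rand::distributions::{Distribution, WeightedIndex};
use rand::rngs::StdRng;
use rand::SeedableRng;

fn main() {
    // `WeightedIndex` samples an index with probability proportional to
    // its weight; it is external API and cannot be renamed by this PR.
    let first_rating_choices = [1, 2, 3, 4];
    let first_rating_prob = [0.15_f64, 0.20, 0.60, 0.05]; // illustrative
    let dist = WeightedIndex::new(&first_rating_prob).unwrap();

    let mut rng = StdRng::seed_from_u64(42);
    let rating = first_rating_choices[dist.sample(&mut rng)];
    println!("sampled first rating: {rating}");
}
```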
6 changes: 3 additions & 3 deletions src/weight_clipper.rs → src/parameter_clipper.rs
@@ -4,7 +4,7 @@ use crate::{
};
use burn::tensor::{backend::Backend, Data, Tensor};

-pub(crate) fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
+pub(crate) fn parameter_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
let val = clip_parameters(&parameters.to_data().convert().value);
Tensor::from_data(
Data::new(val, parameters.shape()).convert(),
@@ -49,14 +49,14 @@ mod tests {
use burn::backend::ndarray::NdArrayDevice;

#[test]
-fn weight_clipper_works() {
+fn parameter_clipper_works() {
let device = NdArrayDevice::Cpu;
let tensor = Tensor::from_floats(
[0.0, -1000.0, 1000.0, 0.0, 1000.0, -1000.0, 1.0, 0.25, -0.1],
&device,
);

-let param: Tensor<1> = weight_clipper(tensor);
+let param: Tensor<1> = parameter_clipper(tensor);
let values = &param.to_data().value;

assert_eq!(
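For context, `clip_parameters` clamps each trained value into a fixed per-parameter range, which is what `parameter_clipper` wraps at the tensor level. A minimal sketch of that clamping idea; the bounds below are placeholders, not the crate's actual limits:

```rust
// Sketch only: hypothetical per-parameter bounds standing in for the
// real table inside clip_parameters.
fn clip_parameters_sketch(parameters: &[f32]) -> Vec<f32> {
    let bounds = [(0.01_f32, 100.0_f32); 9]; // placeholder (min, max) pairs
    parameters
        .iter()
        .zip(bounds.iter())
        .map(|(&p, &(lo, hi))| p.clamp(lo, hi))
        .collect()
}

fn main() {
    let clipped = clip_parameters_sketch(&[0.0, -1000.0, 1000.0, 1.0, 0.25]);
    println!("{clipped:?}"); // [0.01, 0.01, 100.0, 1.0, 0.25]
}
```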
6 changes: 3 additions & 3 deletions src/training.rs
@@ -4,7 +4,7 @@ use crate::dataset::{split_filter_data, FSRSBatcher, FSRSDataset, FSRSItem};
use crate::error::Result;
use crate::model::{Model, ModelConfig};
use crate::pre_training::pretrain;
-use crate::weight_clipper::weight_clipper;
+use crate::parameter_clipper::parameter_clipper;
use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS};
use burn::backend::Autodiff;

@@ -267,7 +267,7 @@ impl<B: Backend> FSRS<B> {

if optimized_parameters
.iter()
-.any(|weight: &f32| weight.is_infinite())
+.any(|parameter: &f32| parameter.is_infinite())
{
return Err(FSRSError::InvalidInput);
}
@@ -358,7 +358,7 @@ fn train<B: AutodiffBackend>(
}
let grads = GradientsParams::from_grads(gradients, &model);
model = optim.step(lr, model, grads);
-model.w = Param::from_tensor(weight_clipper(model.w.val()));
+model.w = Param::from_tensor(parameter_clipper(model.w.val()));
// info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
renderer.render_train(TrainingProgress {
progress,
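The hunk above shows the projected-descent pattern the trainer uses: take an optimizer step, then clamp the parameters back into their valid ranges. A self-contained sketch of the same step-then-clip pattern, with a plain slice standing in for the burn tensor and placeholder bounds:

```rust
// Sketch of "step, then clip": not the crate's trainer, just the pattern.
fn step_and_clip(params: &mut [f32], grads: &[f32], lr: f32) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;              // plain gradient-descent step
        *p = p.clamp(0.01, 100.0); // placeholder bounds, as above
    }
}

fn main() {
    let mut params = [0.5_f32, 2.0, 50.0];
    let grads = [100.0_f32, -1.0, -2000.0];
    step_and_clip(&mut params, &grads, 0.1);
    println!("{params:?}"); // [0.01, 2.1, 100.0]
}
```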