
Commit

clippy::nursery
L-M-Sherlock committed Mar 16, 2024
1 parent cf5578d · commit 0d342dd
Showing 9 changed files with 28 additions and 47 deletions.
27 changes: 3 additions & 24 deletions src/batch_shuffle.rs
@@ -13,7 +13,7 @@ use std::{

use crate::{dataset::FSRSDataset, FSRSItem};

-pub(crate) struct BatchShuffledDataset<I> {
+pub struct BatchShuffledDataset<I> {
dataset: Arc<FSRSDataset>,
indices: Vec<usize>,
input: PhantomData<I>,
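
Most of this commit widens `pub(crate)` items to `pub`. A plausible reading (the commit message names only the lint group) is clippy::nursery's `redundant_pub_crate`, which fires when a `pub(crate)` item lives in a module that is itself private, so `pub(crate)` grants no more reach than `pub` would. A minimal standalone sketch of that situation:

```rust
// Sketch of `redundant_pub_crate` (assumption: this lint drove the
// visibility changes). `inner` is private, so neither struct is visible
// outside the crate; the `(crate)` qualifier buys nothing.
mod inner {
    pub struct Visible; // reaches exactly as far as `inner` does
    pub(crate) struct AlsoVisible; // same reach; the lint prefers `pub`
}
```
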
@@ -130,11 +130,10 @@ where
// When starting a new iteration, we first check if the dataloader was created with an rng,
// implying that we should shuffle the dataset beforehand, while advancing the current
// rng to ensure that each new iteration shuffles the dataset differently.
-let mut rng = self.rng.lock().unwrap();
let dataset = Arc::new(BatchShuffledDataset::with_seed(
self.dataset.clone(),
self.batch_size,
-rng.sample(Standard),
+self.rng.lock().unwrap().sample(Standard),
));
Box::new(BatchShuffledDataloaderIterator::new(
self.strategy.new_like(),
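
The hunk above also tightens the mutex guard's lifetime: rather than binding the guard with `let mut rng = ...` and keeping it alive for the rest of the scope, the lock is acquired, sampled from, and released inside a single expression, in the spirit of the nursery lint `significant_drop_tightening` (an assumption; only the lint group is named). A standalone sketch:

```rust
// Standalone sketch of the tightened lock scope; StdRng and Standard come
// from the `rand` crate (a 0.8-era API is assumed here).
use rand::{distributions::Standard, rngs::StdRng, Rng, SeedableRng};
use std::sync::Mutex;

fn main() {
    let rng = Mutex::new(StdRng::seed_from_u64(42));
    // The MutexGuard lives only for this statement, not the whole block.
    let sample: u64 = rng.lock().unwrap().sample(Standard);
    println!("sampled: {sample}");
}
```
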
@@ -209,7 +208,6 @@ impl<I, O> DataLoaderIterator<O> for BatchShuffledDataloaderIterator<I, O> {

/// A builder for data loaders.
pub struct BatchShuffledDataLoaderBuilder<I, O> {
-strategy: Option<Box<dyn BatchStrategy<I>>>,
batcher: Arc<dyn Batcher<I, O>>,
}

@@ -234,25 +232,9 @@ where
{
Self {
batcher: Arc::new(batcher),
-strategy: None,
}
}

-/// Sets the batch size to a fix number.The [fix batch strategy](FixBatchStrategy)
-/// will be used.
-///
-/// # Arguments
-///
-/// * `batch_size` - The batch size.
-///
-/// # Returns
-///
-/// The data loader builder.
-pub fn batch_size(mut self, batch_size: usize) -> Self {
-self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
-self
-}

/// Builds the data loader.
///
/// # Arguments
@@ -271,10 +253,7 @@ where
let dataset = Arc::new(dataset);

let rng = StdRng::seed_from_u64(seed);
-let strategy = match self.strategy {
-Some(strategy) => strategy,
-None => Box::new(FixBatchStrategy::new(1)),
-};
+let strategy = Box::new(FixBatchStrategy::new(batch_size));

Arc::new(BatchShuffledDataLoader::new(
strategy,
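
Net effect in this file: the builder loses its `strategy: Option<...>` field together with the `batch_size` setter, and `build` now constructs the `FixBatchStrategy` directly from its `batch_size` argument, which also removes the silent default of 1. The new call shape, as it appears in the src/training.rs hunk below:

```rust
// Call site after this commit (taken from the src/training.rs change below):
// batch size goes straight to `build` instead of a `.batch_size(...)` setter.
let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train).build(
    FSRSDataset::from(trainset),
    config.batch_size,
    config.seed,
);
```
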
2 changes: 1 addition & 1 deletion src/cosine_annealing.rs
@@ -1,6 +1,6 @@
use burn::{lr_scheduler::LrScheduler, tensor::backend::Backend, LearningRate};
#[derive(Clone, Debug)]
-pub(crate) struct CosineAnnealingLR {
+pub struct CosineAnnealingLR {
t_max: f64,
eta_min: f64,
init_lr: LearningRate,
6 changes: 3 additions & 3 deletions src/dataset.rs
@@ -50,7 +50,7 @@ impl FSRSItem {
}
}

-pub(crate) struct FSRSBatcher<B: Backend> {
+pub struct FSRSBatcher<B: Backend> {
device: B::Device,
}

@@ -61,7 +61,7 @@ impl<B: Backend> FSRSBatcher<B> {
}

#[derive(Debug, Clone)]
-pub(crate) struct FSRSBatch<B: Backend> {
+pub struct FSRSBatch<B: Backend> {
pub t_historys: Tensor<B, 2, Float>,
pub r_historys: Tensor<B, 2, Float>,
pub delta_ts: Tensor<B, 1, Float>,
@@ -133,7 +133,7 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
}
}

-pub(crate) struct FSRSDataset {
+pub struct FSRSDataset {
items: Vec<FSRSItem>,
}

8 changes: 4 additions & 4 deletions src/inference.rs
@@ -13,10 +13,10 @@ use crate::model::Model;
use crate::training::BCELoss;
use crate::{FSRSError, FSRSItem};
use burn::tensor::ElementConversion;
-pub(crate) const DECAY: f64 = -0.5;
+pub const DECAY: f64 = -0.5;
/// (9/10) ^ (1 / DECAY) - 1
-pub(crate) const FACTOR: f64 = 19f64 / 81f64;
-pub(crate) const S_MIN: f32 = 0.01;
+pub const FACTOR: f64 = 19f64 / 81f64;
+pub const S_MIN: f32 = 0.01;
/// This is a slice for efficiency, but should always be 17 in length.
pub type Parameters = [f32];
use itertools::izip;
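
The doc comment on `FACTOR` records why `19/81` is not arbitrary: with `DECAY = -0.5`, the forgetting curve `R(t, S) = (t/S * FACTOR + 1)^DECAY` (see src/optimal_retention.rs below) is exactly 0.9 at `t = S`, i.e. retrievability is 90% when the elapsed time equals the stability. A standalone check of the identity:

```rust
// Sanity check (not part of the commit) of the FACTOR doc comment:
// (9/10)^(1/DECAY) - 1 = 0.9^(-2) - 1 = 100/81 - 1 = 19/81.
const DECAY: f64 = -0.5;
const FACTOR: f64 = 19f64 / 81f64;

fn main() {
    assert!((0.9f64.powf(1.0 / DECAY) - 1.0 - FACTOR).abs() < 1e-12);
    // Equivalently: retrievability is 90% when t == s.
    let r = |t: f64, s: f64| (t / s * FACTOR + 1.0).powf(DECAY);
    assert!((r(1.0, 1.0) - 0.9).abs() < 1e-12);
}
```
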
@@ -292,7 +292,7 @@ impl<B: Backend> FSRS<B> {
total: items.len(),
};
let model_self = self.model();
-let fsrs_other = FSRS::<B>::new_with_backend(Some(parameters), self.device())?;
+let fsrs_other = Self::new_with_backend(Some(parameters), self.device())?;
let model_other = fsrs_other.model();
for chunk in items.chunks(512) {
let batch = batcher.batch(chunk.to_vec());
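
The `FSRS::<B>::new_with_backend` to `Self::new_with_backend` change matches clippy's `use_self` lint (also nursery; an assumption, as above): inside an impl block, `Self` already names the implementing type, so the call keeps working if the type or its generics are renamed. A minimal sketch:

```rust
// Standalone sketch of the `use_self` rewrite; `Wrapper` is a stand-in type.
struct Wrapper<T>(T);

impl<T: Default> Wrapper<T> {
    fn fresh() -> Self {
        Self(T::default()) // instead of Wrapper::<T>(T::default())
    }
}

fn main() {
    let w: Wrapper<u32> = Wrapper::fresh();
    assert_eq!(w.0, 0);
}
```
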
6 changes: 3 additions & 3 deletions src/model.rs
@@ -16,7 +16,7 @@ pub struct Model<B: Backend> {
pub config: ModelConfig,
}

-pub(crate) trait Get<B: Backend, const N: usize> {
+pub trait Get<B: Backend, const N: usize> {
fn get(&self, n: usize) -> Tensor<B, N>;
}

@@ -177,7 +177,7 @@ impl<B: Backend> Model<B> {
}

#[derive(Debug, Clone)]
-pub(crate) struct MemoryStateTensors<B: Backend> {
+pub struct MemoryStateTensors<B: Backend> {
pub stability: Tensor<B, 1>,
pub difficulty: Tensor<B, 1>,
}
@@ -240,7 +240,7 @@ impl<B: Backend> FSRS<B> {
}
}

-pub(crate) fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
+pub fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
let config = ModelConfig::default();
let mut model = Model::new(config);
model.w = Param::from(Tensor::from_floats(
4 changes: 2 additions & 2 deletions src/optimal_retention.rs
@@ -174,7 +174,7 @@ pub fn simulate(
let mut retrievability = Array1::zeros(deck_size); // Create an array for retrievability

fn power_forgetting_curve(t: f64, s: f64) -> f64 {
-(t / s * FACTOR + 1.0).powf(DECAY)
+(t / s).mul_add(FACTOR, 1.0).powf(DECAY)
}

// Calculate retrievability for entries where has_learned is true
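
This hunk and the Brent-step hunk below are the nursery lint `suboptimal_flops` rewrite (an assumption from the shape of the change): `a * b + c` becomes `a.mul_add(b, c)`, a fused multiply-add that rounds once instead of twice, so it is at least as accurate and can compile to a single FMA instruction. A standalone sketch:

```rust
// Standalone sketch of the mul_add rewrite (not the crate's code).
fn power_forgetting_curve(t: f64, s: f64) -> f64 {
    const DECAY: f64 = -0.5;
    const FACTOR: f64 = 19.0 / 81.0;
    (t / s).mul_add(FACTOR, 1.0).powf(DECAY)
}

fn main() {
    let naive = (10.0 / 5.0) * (19.0 / 81.0) + 1.0_f64;
    let fused = (10.0_f64 / 5.0).mul_add(19.0 / 81.0, 1.0);
    // The two forms agree up to a final rounding step; they are not
    // guaranteed to be bit-identical, which is fine for this use.
    assert!((naive - fused).abs() < 1e-12);
    println!("R(10, 5) = {}", power_forgetting_curve(10.0, 5.0));
}
```
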
@@ -505,7 +505,7 @@ impl<B: Backend> FSRS<B> {
let tol2 = 2.0 * tol1;
let xmid = 0.5 * (a + b);
// check for convergence
-if (x - xmid).abs() < (tol2 - 0.5 * (b - a)) {
+if (x - xmid).abs() < 0.5f64.mul_add(-(b - a), tol2) {
break;
}
if deltax.abs() <= tol1 {
2 changes: 1 addition & 1 deletion src/pre_training.rs
@@ -106,7 +106,7 @@ fn loss(
logloss + l1
}

-pub(crate) const INIT_S_MAX: f32 = 100.0;
+pub const INIT_S_MAX: f32 = 100.0;

fn search_parameters(
mut pretrainset: HashMap<FirstRating, Vec<AverageRecall>>,
16 changes: 9 additions & 7 deletions src/training.rs
@@ -112,7 +112,7 @@ impl CombinedProgressState {
self.splits.iter().map(|s| s.total()).sum()
}

-pub fn finished(&self) -> bool {
+pub const fn finished(&self) -> bool {
self.finished
}
}
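
Marking `finished` as `const fn` follows the nursery lint `missing_const_for_fn` (again an assumption): a getter that only reads a field has no side effects, so `const` costs nothing and additionally lets callers evaluate it at compile time. A standalone sketch:

```rust
// Standalone sketch of `missing_const_for_fn`; `Progress` is a stand-in,
// not the crate's type.
struct Progress {
    finished: bool,
}

impl Progress {
    const fn finished(&self) -> bool {
        self.finished
    }
}

// A `const fn` getter can run in const contexts:
const DONE: bool = Progress { finished: true }.finished();

fn main() {
    assert!(DONE);
}
```
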
@@ -166,7 +166,7 @@ impl MetricsRenderer for ProgressCollector {
}

#[derive(Config)]
-pub(crate) struct TrainingConfig {
+pub struct TrainingConfig {
pub model: ModelConfig,
pub optimizer: AdamConfig,
#[config(default = 5)]
@@ -300,7 +300,7 @@ impl<B: Backend> FSRS<B> {
}

pub fn benchmark(&self, train_set: Vec<FSRSItem>, test_set: Vec<FSRSItem>) -> Vec<f32> {
-let average_recall = calculate_average_recall(&train_set.clone());
+let average_recall = calculate_average_recall(&train_set);
let (pre_train_set, next_train_set) = train_set
.into_iter()
.partition(|item| item.reviews.len() == 2);
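
Dropping the `.clone()` in `calculate_average_recall(&train_set)` is the classic `redundant_clone` fix (a nursery lint; an assumption, as above): the callee only needs a shared borrow, so the clone allocated a full copy that died at the end of the statement. The same reasoning lets `FSRSBatcher::new(device)` below take ownership at `device`'s last use. A standalone sketch:

```rust
// Standalone sketch of `redundant_clone`: `average` takes a borrow, so
// cloning first just allocates a Vec that is dropped immediately.
fn average(xs: &[f32]) -> f32 {
    xs.iter().sum::<f32>() / xs.len() as f32
}

fn main() {
    let train_set = vec![0.9_f32, 0.8, 0.95];
    let with_clone = average(&train_set.clone()); // needless allocation
    let without = average(&train_set); // same result, no copy
    assert_eq!(with_clone, without);
}
```
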
@@ -330,11 +330,13 @@ fn train<B: AutodiffBackend>(
// Training data
let iterations = (trainset.len() / config.batch_size + 1) * config.num_epochs;
let batcher_train = FSRSBatcher::<B>::new(device.clone());
-let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train)
-.batch_size(config.batch_size)
-.build(FSRSDataset::from(trainset), config.batch_size, config.seed);
+let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train).build(
+FSRSDataset::from(trainset),
+config.batch_size,
+config.seed,
+);

-let batcher_valid = FSRSBatcher::new(device.clone());
+let batcher_valid = FSRSBatcher::new(device);
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.build(FSRSDataset::from(testset.clone()));
4 changes: 2 additions & 2 deletions src/weight_clipper.rs
@@ -4,15 +4,15 @@ use crate::{
};
use burn::tensor::{backend::Backend, Data, Tensor};

-pub(crate) fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
+pub fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
let val = clip_parameters(&parameters.to_data().convert().value);
Tensor::from_data(
Data::new(val, parameters.shape()).convert(),
&B::Device::default(),
)
}

-pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
+pub fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
// https://regex101.com/r/21mXNI/1
const CLAMPS: [(f32, f32); 17] = [
(S_MIN, INIT_S_MAX),
