diff --git a/Cargo.lock b/Cargo.lock index 95f004a..18a4e1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -179,9 +179,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e041d5f4eef703500763e599050cba419cd90d464172d71e3d5397baebbf1d8a" +checksum = "3960b57a6ad4baf54d1dba766965e4559c4b9a8f391107fee5de29db57265840" dependencies = [ "burn-core", "burn-train", @@ -189,9 +189,9 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e23c815bc728ac60343b8820fb71e9b4a2c0cb283bfd58828246caacabe6eff" +checksum = "cf9479c28bdce3f2b1541f0a9215628f6256b5f3d66871192a3c56d55171d28e" dependencies = [ "burn-common", "burn-tensor", @@ -202,9 +202,9 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d319a88254df7e9740154c32e862d721d29e5f782c0fdf7004f6b9ed5c8369f" +checksum = "d811c54fa6d9beb38808a1aabd9515c39090720cae572d54f25c041b1702e8fd" dependencies = [ "burn-tensor", "candle-core", @@ -214,9 +214,9 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a14cddb7f93dc985637e21f068a343acdfc4d62232fb11101f88c2739abad249" +checksum = "8d9540b2f45a2d337220e702d7a87572c8e1c78db91a200b22924a8c4a6e9be4" dependencies = [ "async-trait", "derive-new", @@ -235,9 +235,9 @@ dependencies = [ [[package]] name = "burn-compute" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbe641bbe653d04fb070a80946f3db13485e04d7d12104aab9287a1d55b3493c" +checksum = "3e890d8999b25a1a090c2afe198243fc79f0a299efb531a4871c084b0ab9fa11" dependencies = [ "burn-common", "derive-new", @@ -253,9 +253,9 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f3532e2f722bca39aefa69aea2b8e6cf2c3bf70f95ba8421b557082d89ea476" +checksum = "8af6bc0afe55a57ff0b08f52302df4e3d09f96805a4f1e15c521f1082cb02b4f" dependencies = [ "bincode", "burn-autodiff", @@ -281,9 +281,9 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebb03147d7c50f31c673ee7f672543caddd56bc5de906810db23e396ca062054" +checksum = "3feae7766b56e947d38ac4d6903388270d848609339a147a513145703426f6db" dependencies = [ "csv", "derive-new", @@ -307,9 +307,9 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dbf7e7f4154821f1a74c709ed2191304701e6f56b6221aec8585b8a16d16ae5" +checksum = "8618ac2c171c7054ffd3ce8da15c3d4b11dc805eb393065c74c05882ef79d931" dependencies = [ "derive-new", "proc-macro2", @@ -319,9 +319,9 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "934015329ca3b41a6a6bc7b6a4eedcda04d899085e0b3273e7fb330358c15cf8" +checksum = "8d77b882d131a67d15f91b915fb3e0a5add73547e7352310d33c877fbe77c79e" dependencies = [ "burn-common", 
"burn-tensor", @@ -334,9 +334,9 @@ dependencies = [ [[package]] name = "burn-jit" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d257cec36c1b4404c79355492a0c32d0775ed5d7826241051323eb88f1e633dc" +checksum = "0cb62a93030a690c329b95c01b43e3064a4bd36031e9111d537641d36e42f3ac" dependencies = [ "burn-common", "burn-compute", @@ -354,9 +354,9 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f3a7d13e0116b4e442bda45aa9eb8a4cc3b70cf7d67197b13d539753275428c" +checksum = "05f40bb0b5938937a721045752f1ec1baee8a873429fd17e6e6f2155c6cdf33a" dependencies = [ "burn-autodiff", "burn-common", @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ee78099b81128ba1122c645344cb7126c1fadfc05b284150efd94731001f0a7" +checksum = "8cb9c2b547499a3d990e93b950965b9a478edfec4a7bf98d5d4412ff8c897129" dependencies = [ "burn-tensor", "half", @@ -386,9 +386,9 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9395b25136b8fff2ca293dc30e8ca915cc811ed48ffbb147063b6c9c7fcba6a" +checksum = "bfa19c21f54e1a189be3bbaec45efafdf1c89b2763710b381c9f32ae25e7dbe8" dependencies = [ "burn-common", "derive-new", @@ -402,9 +402,9 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95f83ed597cdb313fb0e18b389f88b96d5bcd1a37620adc969fe2934d486ff" +checksum = "0a0014ee82ef967bd82dda378cfaf340f255c39c729e29ac3bc65d3107e4c7ee" dependencies = [ "burn-common", "burn-core", @@ -418,9 +418,9 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6377670147b65387c807938b4f77a0b149b154ecc8b749f66ad068d345efac14" +checksum = "1575890471123109c6aeb725c52ac649fa9e0013e2303f57dc534d5e0cb857e5" dependencies = [ "burn-common", "burn-compute", @@ -1017,7 +1017,7 @@ checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" [[package]] name = "fsrs" -version = "1.4.0" +version = "1.4.3" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index c921bfe..a327d3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "1.4.0" +version = "1.4.3" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/README.md b/README.md index a08adf8..7b8402d 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ to `.git/hooks/pre-commit`, then `chmod +x .git/hooks/pre-commit` ## Bindings - python +- nodejs ## Q&A diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index fe4ec6c..9ed1501 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -1,70 +1,34 @@ -use burn::data::{ - dataloader::{ - batcher::DynBatcher, BatchStrategy, DataLoader, DataLoaderIterator, FixBatchStrategy, - Progress, - }, - dataset::Dataset, -}; +use std::sync::Mutex; -use rand::{distributions::Standard, prelude::SliceRandom, rngs::StdRng, Rng, SeedableRng}; -use std::{ - marker::PhantomData, - sync::{Arc, Mutex}, -}; +use burn::data::dataloader::batcher::Batcher; +use 
burn::data::dataloader::{DataLoaderIterator, Progress}; +use burn::prelude::Backend; +use rand::seq::SliceRandom; +use rand::SeedableRng; -use crate::{dataset::FSRSDataset, FSRSItem}; +use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset}; -pub(crate) struct BatchShuffledDataset { - dataset: Arc, - indices: Vec, - input: PhantomData, +#[derive(Clone)] +pub(crate) struct BatchTensorDataset { + dataset: Vec>, } -impl BatchShuffledDataset { +impl BatchTensorDataset { /// Creates a new shuffled dataset. - pub fn new(dataset: Arc, batch_size: usize, rng: &mut StdRng) -> Self { - let len = dataset.len(); - - // Calculate the number of batches - // 计算批数 - let num_batches = (len + batch_size - 1) / batch_size; - - // Create a vector of batch indices and shuffle it - // 创建一个批数索引的向量并打乱 - let mut batch_indices: Vec<_> = (0..num_batches).collect(); - batch_indices.shuffle(rng); - // info!("batch_indices: {:?}", &batch_indices); - // Generate the corresponding item indices for each shuffled batch - // 为每个打乱的批次生成相应的元素索引 - let mut indices = vec![]; - for batch_index in batch_indices { - let start_index = batch_index * batch_size; - let end_index = (start_index + batch_size).min(len); - indices.extend(start_index..end_index); - } - // info!("indices: {:?}", &indices); - Self { - dataset, - indices, - input: PhantomData, - } - } - - /// Creates a new shuffled dataset with a fixed seed. - pub fn with_seed(dataset: Arc, batch_size: usize, seed: u64) -> Self { - let mut rng = StdRng::seed_from_u64(seed); - Self::new(dataset, batch_size, &mut rng) + pub fn new(dataset: FSRSDataset, batch_size: usize, device: B::Device) -> Self { + let batcher = FSRSBatcher::::new(device); + let dataset = dataset + .items + .chunks(batch_size) + .map(|items| batcher.batch(items.to_vec())) + .collect(); + Self { dataset } } } -impl Dataset for BatchShuffledDataset { - fn get(&self, index: usize) -> Option { - let shuffled_index = self.indices.get(index)?; - // info!( - // "original index: {}, shuffled index: {}", - // index, shuffled_index - // ); - self.dataset.get(*shuffled_index) +impl BatchTensorDataset { + fn get(&self, index: usize) -> Option> { + self.dataset.get(index).cloned() } fn len(&self) -> usize { @@ -72,129 +36,36 @@ impl Dataset for BatchShuffledDataset { } } -/// A data loader that can be used to iterate over a dataset in batches. -pub(crate) struct BatchShuffledDataLoader { - strategy: Box>, - dataset: Arc, - batcher: Box>, +pub struct ShuffleDataLoader { + dataset: BatchTensorDataset, rng: Mutex, - batch_size: usize, } -impl BatchShuffledDataLoader { - /// Creates a new batch data loader. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// * `rng` - The rng determining if the dataset is shuffled each time a dataloader - /// iterator is created. - /// - /// # Returns - /// - /// The batch data loader. - pub fn new( - strategy: Box>, - dataset: Arc, - batcher: Box>, - rng: rand::rngs::StdRng, - batch_size: usize, - ) -> Self { +impl ShuffleDataLoader { + pub fn new(dataset: BatchTensorDataset, seed: u64) -> Self { Self { - strategy, dataset, - batcher, - rng: Mutex::new(rng), - batch_size, + rng: Mutex::new(rand::rngs::StdRng::seed_from_u64(seed)), } } } -/// A data loader iterator that can be used to iterate over a data loader. 
-struct BatchShuffledDataloaderIterator { +pub(crate) struct ShuffleDataLoaderIterator { current_index: usize, - strategy: Box>, - dataset: Arc>, - batcher: Box>, -} - -impl DataLoader for BatchShuffledDataLoader -where - BatchShuffledDataset: Dataset, -{ - fn iter<'a>(&'a self) -> Box + 'a> { - // When starting a new iteration, we first check if the dataloader was created with an rng, - // implying that we should shuffle the dataset beforehand, while advancing the current - // rng to ensure that each new iteration shuffles the dataset differently. - let dataset = Arc::new(BatchShuffledDataset::with_seed( - self.dataset.clone(), - self.batch_size, - self.rng.lock().unwrap().sample(Standard), - )); - Box::new(BatchShuffledDataloaderIterator::new( - self.strategy.clone_dyn(), - dataset, - self.batcher.clone_dyn(), - )) - } - - fn num_items(&self) -> usize { - self.dataset.len() - } + indices: Vec, + dataset: BatchTensorDataset, } -impl BatchShuffledDataloaderIterator -where - BatchShuffledDataset: Dataset, -{ - /// Creates a new batch data loader iterator. - /// - /// # Arguments - /// - /// * `strategy` - The batch strategy. - /// * `dataset` - The dataset. - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The batch data loader iterator. - pub fn new( - strategy: Box>, - dataset: Arc>, - batcher: Box>, - ) -> Self { +impl ShuffleDataLoaderIterator { + pub(crate) fn new(dataset: BatchTensorDataset, indices: Vec) -> Self { Self { current_index: 0, - strategy, + indices, dataset, - batcher, - } - } -} - -impl Iterator for BatchShuffledDataloaderIterator { - type Item = O; - - fn next(&mut self) -> Option { - while let Some(item) = self.dataset.get(self.current_index) { - self.current_index += 1; - self.strategy.add(item); - - if let Some(items) = self.strategy.batch(false) { - return Some(self.batcher.batch(items)); - } } - - let items = self.strategy.batch(true)?; - - Some(self.batcher.batch(items)) } -} -impl DataLoaderIterator for BatchShuffledDataloaderIterator { - fn progress(&self) -> Progress { + pub(crate) fn progress(&self) -> Progress { Progress { items_processed: self.current_index, items_total: self.dataset.len(), @@ -202,323 +73,106 @@ impl DataLoaderIterator for BatchShuffledDataloaderIterator { } } -/// A builder for data loaders. -pub struct BatchShuffledDataLoaderBuilder { - batcher: Box>, -} +impl Iterator for ShuffleDataLoaderIterator { + type Item = FSRSBatch; -impl BatchShuffledDataLoaderBuilder -where - I: Send + Sync + Clone + std::fmt::Debug + 'static, - O: Send + Clone + std::fmt::Debug + 'static, - BatchShuffledDataset: Dataset, -{ - /// Creates a new data loader builder. - /// - /// # Arguments - /// - /// * `batcher` - The batcher. - /// - /// # Returns - /// - /// The data loader builder. - pub fn new(batcher: B) -> Self - where - B: DynBatcher + 'static, - { - Self { - batcher: Box::new(batcher), + fn next(&mut self) -> Option { + if let Some(index) = self.indices.get(self.current_index) { + self.current_index += 1; + return self.dataset.get(*index); } + None } +} - /// Builds the data loader. - /// - /// # Arguments - /// - /// * `dataset` - The dataset. - /// - /// # Returns - /// - /// The data loader. 
- pub fn build( - self, - dataset: FSRSDataset, - batch_size: usize, - seed: u64, - ) -> Arc> { - let dataset = Arc::new(dataset); - - let rng = StdRng::seed_from_u64(seed); - let strategy = Box::new(FixBatchStrategy::new(batch_size)); +impl DataLoaderIterator> for ShuffleDataLoaderIterator { + fn progress(&self) -> Progress { + Progress::new(self.current_index, self.dataset.len()) + } +} - Arc::new(BatchShuffledDataLoader::new( - strategy, - dataset, - self.batcher, - rng, - batch_size, - )) +impl ShuffleDataLoader { + pub(crate) fn iter(&self) -> ShuffleDataLoaderIterator { + let mut indices: Vec<_> = (0..self.dataset.len()).collect(); + indices.shuffle(&mut *self.rng.lock().unwrap()); + ShuffleDataLoaderIterator::new(self.dataset.clone(), indices) } } #[cfg(test)] mod tests { - use burn::backend::{ndarray::NdArrayDevice, NdArray}; + use burn::{ + backend::{ndarray::NdArrayDevice, NdArray}, + tensor::Shape, + }; use super::*; use crate::{ - convertor_tests::anki21_sample_file_converted_to_fsrs, - dataset::{prepare_training_data, FSRSBatcher, FSRSDataset}, - FSRSItem, FSRSReview, + convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::prepare_training_data, }; #[test] - fn batch_shuffle_dataloader() { + fn test_simple_dataloader() { let train_set = anki21_sample_file_converted_to_fsrs(); let (_pre_train_set, train_set) = prepare_training_data(train_set); let dataset = FSRSDataset::from(train_set); let batch_size = 512; - let seed = 42; + let seed = 114514; let device = NdArrayDevice::Cpu; type Backend = NdArray; - let batcher = FSRSBatcher::::new(device); - let dataloader = - BatchShuffledDataLoaderBuilder::new(batcher).build(dataset, batch_size, seed); - let item = dataloader.iter().next().unwrap(); + + let dataset = BatchTensorDataset::::new(dataset, batch_size, device); + let dataloader = ShuffleDataLoader::new(dataset, seed); + let mut iterator = dataloader.iter(); + // dbg!(&iterator.indices); + let batch = iterator.next().unwrap(); assert_eq!( - item.t_historys.shape(), - burn::tensor::Shape { dims: [6, 512] } + batch.t_historys.shape(), + Shape { + dims: [7, batch_size] + } ); - let item2 = dataloader.iter().next().unwrap(); + let batch = iterator.next().unwrap(); assert_eq!( - item2.t_historys.shape(), - burn::tensor::Shape { dims: [4, 512] } + batch.t_historys.shape(), + Shape { + dims: [6, batch_size] + } ); - } - #[test] - fn batch_shuffle() { - let dataset = Arc::new(FSRSDataset::from(anki21_sample_file_converted_to_fsrs())); - let batch_size = 10; - let seed = 42; - let batch_shuffled_dataset = BatchShuffledDataset::with_seed(dataset, batch_size, seed); + let lengths = iterator + .map(|batch| batch.t_historys.shape().dims[0]) + .collect::>(); assert_eq!( - (0..batch_shuffled_dataset.len().min(batch_size)) - .map(|i| batch_shuffled_dataset.get(i).unwrap()) - .collect::>(), - [ - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 1, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 2 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - 
FSRSReview { - rating: 1, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 1, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 3 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 1 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - }, - FSRSReview { - rating: 3, - delta_t: 2 - } - ] - }, - FSRSItem { - reviews: vec![ - FSRSReview { - rating: 3, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 0 - }, - FSRSReview { - rating: 4, - delta_t: 1 - } - ] - } + lengths, + vec![ + 48, 6, 8, 5, 11, 5, 10, 19, 6, 13, 9, 6, 5, 3, 9, 6, 3, 13, 7, 5, 4, 4, 4, 6, 4, 3, ] ); - } - #[test] - fn item_shuffle() { - use burn::data::dataset::transform::ShuffledDataset; - let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs()); - let seed = 42; - let shuffled_dataset = ShuffledDataset::with_seed(dataset, seed); - for i in 0..shuffled_dataset.len().min(10) { - dbg!(shuffled_dataset.get(i).unwrap()); - } + let mut iterator = dataloader.iter(); + // dbg!(&iterator.indices); + let batch = iterator.next().unwrap(); + assert_eq!( + batch.t_historys.shape(), + Shape { + dims: [19, batch_size] + } + ); + let batch = iterator.next().unwrap(); + assert_eq!( + batch.t_historys.shape(), + Shape { + dims: [9, batch_size] + } + ); + + let lengths = iterator + .map(|batch| batch.t_historys.shape().dims[0]) + .collect::>(); + assert_eq!( + lengths, + vec![3, 11, 3, 6, 6, 6, 5, 5, 7, 6, 4, 9, 10, 4, 48, 3, 4, 5, 13, 13, 7, 5, 4, 8, 6, 6] + ); } } diff --git a/src/dataset.rs b/src/dataset.rs index 1bd4416..043b143 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; /// first one. 
/// When used during review, the last item should include the correct delta_t, but /// the provided rating is ignored as all four ratings are returned by .next_states() -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Default)] pub struct FSRSItem { pub reviews: Vec, } @@ -51,7 +51,7 @@ impl FSRSItem { .reviews .iter() .find(|review| review.delta_t > 0) - .unwrap() + .expect("Invalid FSRS item: at least one review with delta_t > 0 is required") } pub(crate) fn r_matrix_index(&self) -> (u32, u32, u32) { @@ -107,15 +107,25 @@ impl Batcher> for FSRSBatcher { delta_t.resize(pad_size, 0); rating.resize(pad_size, 0); let delta_t = Tensor::from_data( - Data::new(delta_t, Shape { dims: [pad_size] }).convert(), + Data::new( + delta_t, + Shape { + dims: [1, pad_size], + }, + ) + .convert(), &self.device, - ) - .unsqueeze(); + ); let rating = Tensor::from_data( - Data::new(rating, Shape { dims: [pad_size] }).convert(), + Data::new( + rating, + Shape { + dims: [1, pad_size], + }, + ) + .convert(), &self.device, - ) - .unsqueeze(); + ); (delta_t, rating) }) .unzip(); @@ -156,7 +166,7 @@ impl Batcher> for FSRSBatcher { } pub(crate) struct FSRSDataset { - items: Vec, + pub(crate) items: Vec, } impl Dataset for FSRSDataset { diff --git a/src/inference.rs b/src/inference.rs index 9e13f4b..5ae7388 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -18,6 +18,7 @@ pub(crate) const DECAY: f64 = -0.5; /// (9/10) ^ (1 / DECAY) - 1 pub(crate) const FACTOR: f64 = 19f64 / 81f64; pub(crate) const S_MIN: f32 = 0.01; +pub(crate) const S_MAX: f32 = 36500.0; /// This is a slice for efficiency, but should always be 17 in length. pub type Parameters = [f32]; use itertools::izip; diff --git a/src/model.rs b/src/model.rs index 7e6c067..4e5f172 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,5 +1,5 @@ use crate::error::{FSRSError, Result}; -use crate::inference::{Parameters, DECAY, FACTOR, S_MIN}; +use crate::inference::{Parameters, DECAY, FACTOR, S_MAX, S_MIN}; use crate::parameter_clipper::clip_parameters; use crate::DEFAULT_PARAMETERS; use burn::backend::ndarray::NdArrayDevice; @@ -163,7 +163,7 @@ impl Model { ) }; MemoryStateTensors { - stability: new_s.clamp(S_MIN, 36500.0), + stability: new_s.clamp(S_MIN, S_MAX), difficulty: new_d, } } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index fc381ce..09e3941 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -1,6 +1,7 @@ use crate::error::{FSRSError, Result}; -use crate::inference::{next_interval, ItemProgress, Parameters, DECAY, FACTOR, S_MIN}; +use crate::inference::{next_interval, ItemProgress, Parameters, DECAY, FACTOR, S_MAX, S_MIN}; use crate::model::check_and_fill_parameters; +use crate::parameter_clipper::clip_parameters; use crate::FSRS; use burn::tensor::backend::Backend; use itertools::{izip, Itertools}; @@ -71,12 +72,13 @@ impl Default for SimulatorConfig { fn stability_after_success(w: &[f32], s: f32, r: f32, d: f32, rating: usize) -> f32 { let hard_penalty = if rating == 2 { w[15] } else { 1.0 }; let easy_bonus = if rating == 4 { w[16] } else { 1.0 }; - s * (f32::exp(w[8]) + (s * (f32::exp(w[8]) * (11.0 - d) * s.powf(-w[9]) * (f32::exp((1.0 - r) * w[10]) - 1.0) * hard_penalty) - .mul_add(easy_bonus, 1.0) + .mul_add(easy_bonus, 1.0)) + .clamp(S_MIN, S_MAX) } fn stability_after_failure(w: &[f32], s: f32, r: f32, d: f32) -> f32 { @@ -85,7 +87,7 @@ fn stability_after_failure(w: &[f32], s: f32, r: f32, d: f32) -> f32 { } fn 
stability_short_term(w: &[f32], s: f32, rating_offset: f32, session_len: f32) -> f32 { - s * (w[17] * (rating_offset + session_len * w[18])).exp() + (s * (w[17] * (rating_offset + session_len * w[18])).exp()).clamp(S_MIN, S_MAX) } fn init_d(w: &[f32], rating: usize) -> f32 { @@ -132,6 +134,7 @@ pub fn simulate( existing_cards: Option>, ) -> Result<(Array1, Array1, Array1, Array1), FSRSError> { let w = &check_and_fill_parameters(w)?; + let w = &clip_parameters(w); let SimulatorConfig { deck_size, learn_span, diff --git a/src/parameter_clipper.rs b/src/parameter_clipper.rs index d84c0e9..641adc0 100644 --- a/src/parameter_clipper.rs +++ b/src/parameter_clipper.rs @@ -2,13 +2,23 @@ use crate::{ inference::{Parameters, S_MIN}, pre_training::INIT_S_MAX, }; -use burn::tensor::{backend::Backend, Data, Tensor}; +use burn::{ + module::Param, + tensor::{backend::Backend, Data, Tensor}, +}; -pub(crate) fn parameter_clipper(parameters: Tensor) -> Tensor { - let val = clip_parameters(¶meters.to_data().convert().value); - Tensor::from_data( - Data::new(val, parameters.shape()).convert(), - &B::Device::default(), +pub(crate) fn parameter_clipper( + parameters: Param>, +) -> Param> { + let (id, val) = parameters.consume(); + let clipped = clip_parameters(&val.to_data().convert().value); + Param::initialized( + id, + Tensor::from_data( + Data::new(clipped, val.shape()).convert(), + &B::Device::default(), + ) + .require_grad(), ) } @@ -58,7 +68,7 @@ mod tests { &device, ); - let param: Tensor<1> = parameter_clipper(tensor); + let param = parameter_clipper(Param::from_tensor(tensor)); let values = ¶m.to_data().value; assert_eq!( diff --git a/src/training.rs b/src/training.rs index 132a049..995c164 100644 --- a/src/training.rs +++ b/src/training.rs @@ -1,6 +1,6 @@ -use crate::batch_shuffle::BatchShuffledDataLoaderBuilder; +use crate::batch_shuffle::{BatchTensorDataset, ShuffleDataLoader}; use crate::cosine_annealing::CosineAnnealingLR; -use crate::dataset::{prepare_training_data, FSRSBatcher, FSRSDataset, FSRSItem}; +use crate::dataset::{prepare_training_data, FSRSDataset, FSRSItem}; use crate::error::Result; use crate::model::{Model, ModelConfig}; use crate::parameter_clipper::parameter_clipper; @@ -9,7 +9,6 @@ use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS}; use burn::backend::Autodiff; use wasm_bindgen::prelude::*; -use burn::data::dataloader::DataLoaderBuilder; use burn::lr_scheduler::LrScheduler; use burn::module::AutodiffModule; use burn::nn::loss::Reduction; @@ -19,7 +18,7 @@ use burn::tensor::backend::Backend; use burn::tensor::{Int, Tensor}; use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress}; use burn::train::TrainingInterrupter; -use burn::{config::Config, module::Param, tensor::backend::AutodiffBackend}; +use burn::{config::Config, tensor::backend::AutodiffBackend}; use core::marker::PhantomData; use log::info; @@ -199,10 +198,12 @@ pub(crate) struct TrainingConfig { pub num_epochs: usize, #[config(default = 512)] pub batch_size: usize, - #[config(default = 42)] + #[config(default = 2023)] pub seed: u64, #[config(default = 4e-2)] pub learning_rate: f64, + #[config(default = 64)] + pub max_seq_len: usize, } pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 { @@ -237,7 +238,7 @@ impl FSRS { }; let average_recall = calculate_average_recall(&train_set); - let (pre_train_set, train_set) = prepare_training_data(train_set); + let (pre_train_set, mut train_set) = prepare_training_data(train_set); if train_set.len() < 8 { finish_progress(); return 
Ok(DEFAULT_PARAMETERS.to_vec()); @@ -256,14 +257,15 @@ impl FSRS { finish_progress(); return Ok(pretrained_parameters); } - let config = TrainingConfig::new( ModelConfig { freeze_stability: false, initial_stability: Some(initial_stability), }, - AdamConfig::new(), + AdamConfig::new().with_epsilon(1e-8), ); + train_set.retain(|item| item.reviews.len() <= config.max_seq_len); + train_set.sort_by_cached_key(|item| item.reviews.len()); if let Some(progress) = &progress { let progress_state = ProgressState { @@ -318,7 +320,7 @@ impl FSRS { Ok(optimized_parameters) } - pub fn benchmark(&self, train_set: Vec) -> Vec { + pub fn benchmark(&self, mut train_set: Vec) -> Vec { let average_recall = calculate_average_recall(&train_set); let (pre_train_set, _next_train_set) = train_set .clone() @@ -330,8 +332,10 @@ impl FSRS { freeze_stability: false, initial_stability: Some(initial_stability), }, - AdamConfig::new(), + AdamConfig::new().with_epsilon(1e-8), ); + train_set.retain(|item| item.reviews.len() <= config.max_seq_len); + train_set.sort_by_cached_key(|item| item.reviews.len()); let model = train::>(train_set.clone(), train_set, &config, self.device(), None); let parameters: Vec = model.unwrap().w.val().to_data().convert().value; @@ -350,17 +354,19 @@ fn train( // Training data let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs; - let batcher_train = FSRSBatcher::::new(device.clone()); - let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train).build( + let batch_dataset = BatchTensorDataset::::new( FSRSDataset::from(train_set), config.batch_size, - config.seed, + device.clone(), ); + let dataloader_train = ShuffleDataLoader::new(batch_dataset, config.seed); - let batcher_valid = FSRSBatcher::new(device); - let dataloader_valid = DataLoaderBuilder::new(batcher_valid) - .batch_size(config.batch_size) - .build(FSRSDataset::from(test_set.clone())); + let batch_dataset = BatchTensorDataset::::new( + FSRSDataset::from(test_set.clone()), + config.batch_size, + device, + ); + let dataloader_valid = ShuffleDataLoader::new(batch_dataset, config.seed); let mut lr_scheduler = CosineAnnealingLR::init(iterations as f64, config.learning_rate); let interrupter = TrainingInterrupter::new(); @@ -389,7 +395,7 @@ fn train( item.r_historys, item.delta_ts, item.labels, - Reduction::Mean, + Reduction::Sum, ); let mut gradients = loss.backward(); if model.config.freeze_stability { @@ -397,7 +403,7 @@ fn train( } let grads = GradientsParams::from_grads(gradients, &model); model = optim.step(lr, model, grads); - model.w = Param::from_tensor(parameter_clipper(model.w.val())); + model.w = parameter_clipper(model.w); // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr); renderer.render_train(TrainingProgress { progress, @@ -489,7 +495,8 @@ mod tests { let config = ModelConfig::default(); let device = NdArrayDevice::Cpu; - let model: Model>> = config.init(); + type B = Autodiff>; + let mut model: Model = config.init(); let item = FSRSBatch { t_historys: Tensor::from_floats( @@ -533,7 +540,6 @@ mod tests { let gradients = loss.backward(); let w_grad = model.w.grad(&gradients).unwrap(); - dbg!(&w_grad); Data::from([ -0.05832, -0.00682, -0.00255, 0.010539, -0.05128, 1.364291, 0.083658, -0.95023, @@ -541,6 +547,109 @@ mod tests { 0.202374, 0.214104, 0.032307, ]) .assert_approx_eq(&w_grad.clone().into_data(), 5); + + let config = + TrainingConfig::new(ModelConfig::default(), AdamConfig::new().with_epsilon(1e-8)); + let mut optim = config.optimizer.init::>(); + let lr = 
0.04; + let grads = GradientsParams::from_grads(gradients, &model); + model = optim.step(lr, model, grads); + model.w = parameter_clipper(model.w); + assert_eq!( + model.w.val().to_data(), + Data::from([ + 0.44255, 1.22385, 3.2129998, 15.65105, 7.2349, 0.4945, 1.4204, 0.0446, 1.5057501, + 0.1592, 0.97925, 1.9794999, 0.07000001, 0.33605, 2.3097994, 0.2715, 2.9498, + 0.47655, 0.62210006 + ]) + ); + + let item = FSRSBatch { + t_historys: Tensor::from_floats( + Data::from([ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 1.0, 3.0], + [1.0, 3.0, 3.0, 5.0], + [3.0, 6.0, 6.0, 12.0], + ]), + &device, + ), + r_historys: Tensor::from_floats( + Data::from([ + [1.0, 2.0, 3.0, 4.0], + [3.0, 4.0, 2.0, 4.0], + [1.0, 4.0, 4.0, 3.0], + [4.0, 3.0, 3.0, 3.0], + [3.0, 1.0, 3.0, 3.0], + [2.0, 3.0, 3.0, 4.0], + ]), + &device, + ), + delta_ts: Tensor::from_floats(Data::from([4.0, 11.0, 12.0, 23.0]), &device), + labels: Tensor::from_ints(Data::from([1, 1, 1, 0]), &device), + }; + + let loss = model.forward_classification( + item.t_historys, + item.r_historys, + item.delta_ts, + item.labels, + Reduction::Sum, + ); + assert_eq!(loss.clone().into_data().convert::().value[0], 4.176347); + let gradients = loss.backward(); + let w_grad = model.w.grad(&gradients).unwrap(); + Data::from([ + -0.0401341, + -0.0061790533, + -0.00288913, + 0.01216853, + -0.05624995, + 1.147413, + 0.068084724, + -0.6906936, + 0.48760873, + -2.5428302, + 0.49044546, + -0.011574259, + 0.037729632, + -0.09633919, + -0.0009513022, + -0.12789416, + 0.19088513, + 0.2574597, + 0.049311582, + ]) + .assert_approx_eq(&w_grad.clone().into_data(), 5); + let grads = GradientsParams::from_grads(gradients, &model); + model = optim.step(lr, model, grads); + model.w = parameter_clipper(model.w); + assert_eq!( + model.w.val().to_data(), + Data::from([ + 0.48150504, + 1.2636971, + 3.2530522, + 15.611003, + 7.2749534, + 0.45482785, + 1.3808222, + 0.083782874, + 1.4658877, + 0.19898315, + 0.9393105, + 2.0193, + 0.030164223, + 0.37562984, + 2.3498251, + 0.3112984, + 2.909878, + 0.43652722, + 0.5825156 + ]) + ); } #[test]
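
Note on the batch_shuffle.rs rewrite: the old generic BatchShuffledDataLoader shuffled item indices batch-by-batch and re-ran the batcher on every epoch, while the new BatchTensorDataset materializes every FSRSBatch once via FSRSBatcher, and ShuffleDataLoader only reshuffles the order of those prebuilt batches each time iter() is called. A minimal standalone sketch of that idea, independent of burn and using a hypothetical PreBatchedLoader type (rand 0.8 assumed):

```rust
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use std::sync::Mutex;

/// Batches are built once; only their order changes between epochs.
struct PreBatchedLoader<T> {
    batches: Vec<Vec<T>>,
    rng: Mutex<StdRng>,
}

impl<T: Clone> PreBatchedLoader<T> {
    fn new(items: &[T], batch_size: usize, seed: u64) -> Self {
        Self {
            // Stands in for FSRSBatcher::batch() being applied to each chunk in the real code.
            batches: items.chunks(batch_size).map(|c| c.to_vec()).collect(),
            rng: Mutex::new(StdRng::seed_from_u64(seed)),
        }
    }

    /// Each call yields the same batches in a freshly shuffled order,
    /// advancing the shared RNG so successive epochs differ.
    fn iter(&self) -> impl Iterator<Item = Vec<T>> + '_ {
        let mut indices: Vec<usize> = (0..self.batches.len()).collect();
        indices.shuffle(&mut *self.rng.lock().unwrap());
        indices.into_iter().map(|i| self.batches[i].clone())
    }
}

fn main() {
    let loader = PreBatchedLoader::new(&(0..10).collect::<Vec<_>>(), 3, 42);
    for batch in loader.iter() {
        println!("{batch:?}");
    }
}
```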
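
The training changes keep only items with at most max_seq_len (default 64) reviews and sort the survivors by history length before batching. Since each FSRSBatch pads only to the longest history within its own chunk, grouping similar lengths keeps the [seq_len, batch_size] tensors small, which is why the dataloader test now sees a different first dimension for every batch. A sketch of that preprocessing under illustrative types (Item and reviews_len are stand-ins, not crate names):

```rust
/// Illustrative stand-in for FSRSItem: only the history length matters here.
#[derive(Debug, Clone)]
struct Item {
    reviews_len: usize,
}

fn prepare(mut items: Vec<Item>, max_seq_len: usize, batch_size: usize) -> Vec<Vec<Item>> {
    // Drop over-long histories, then sort so each chunk holds similar lengths.
    items.retain(|it| it.reviews_len <= max_seq_len);
    items.sort_by_key(|it| it.reviews_len);
    items.chunks(batch_size).map(|c| c.to_vec()).collect()
}

fn main() {
    let items: Vec<Item> = [3usize, 70, 5, 12, 4, 9, 65, 6]
        .into_iter()
        .map(|n| Item { reviews_len: n })
        .collect();
    for batch in prepare(items, 64, 3) {
        // Padding cost per batch is the longest history inside it.
        let pad = batch.iter().map(|it| it.reviews_len).max().unwrap();
        println!("batch of {} items pads to length {pad}", batch.len());
    }
}
```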
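
inference.rs gains S_MAX = 36500.0 next to S_MIN = 0.01, and both the model (new_s.clamp(S_MIN, S_MAX)) and the simulator's stability formulas now clamp into that range; simulate() additionally passes the parameters through clip_parameters. A condensed, self-contained sketch of the clamped short-term update, reusing the formula and parameter indices (w[17], w[18]) from the patch:

```rust
const S_MIN: f32 = 0.01;
const S_MAX: f32 = 36500.0; // 100 years in days

/// Short-term stability update, clamped so stability can never leave [S_MIN, S_MAX].
fn stability_short_term(w: &[f32], s: f32, rating_offset: f32, session_len: f32) -> f32 {
    (s * (w[17] * (rating_offset + session_len * w[18])).exp()).clamp(S_MIN, S_MAX)
}

fn main() {
    // With a large positive exponent the unclamped value would shoot far past S_MAX.
    let mut w = vec![0.0f32; 19];
    w[17] = 5.0;
    w[18] = 1.0;
    println!("{}", stability_short_term(&w, 100.0, 1.0, 3.0)); // prints 36500
}
```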
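
parameter_clipper now takes and returns Param<Tensor<B, 1>> rather than a bare tensor: it consumes the param, clips the values, and rebuilds it with Param::initialized plus require_grad, so the training loop can write model.w = parameter_clipper(model.w) right after the optimizer step. Keeping the original ParamId presumably matters because burn keys per-parameter optimizer state by it, though the patch does not say so explicitly. A sketch assuming only the burn 0.13 calls already used in the patch; the uniform lo/hi bounds stand in for the crate's real per-index clip_parameters:

```rust
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::module::Param;
use burn::tensor::{backend::Backend, Data, Tensor};

/// Clip in value space while keeping the same ParamId, so the returned Param
/// still matches the optimizer's per-parameter state.
fn parameter_clipper<B: Backend>(
    parameters: Param<Tensor<B, 1>>,
    lo: f32,
    hi: f32, // illustrative bounds; the real clipper uses per-index ranges
) -> Param<Tensor<B, 1>> {
    let (id, val) = parameters.consume();
    let clipped: Vec<f32> = val
        .to_data()
        .convert()
        .value
        .into_iter()
        .map(|v: f32| v.clamp(lo, hi))
        .collect();
    Param::initialized(
        id,
        Tensor::from_data(Data::new(clipped, val.shape()).convert(), &B::Device::default())
            .require_grad(),
    )
}

fn main() {
    type B = NdArray;
    let t: Tensor<B, 1> = Tensor::from_floats(Data::from([-1.0, 0.5, 99.0]), &NdArrayDevice::Cpu);
    let clipped = parameter_clipper(Param::from_tensor(t), 0.01, 10.0);
    println!("{:?}", clipped.val().to_data().value); // [0.01, 0.5, 10.0]
}
```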
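
The training loss also moves from Reduction::Mean to Reduction::Sum, with Adam's epsilon pinned to 1e-8. A tiny numeric illustration of how the two reductions scale a binary cross-entropy loss, and hence its gradients, with batch size:

```rust
/// Binary cross-entropy for a single prediction/label pair.
fn bce(p: f32, y: f32) -> f32 {
    -(y * p.ln() + (1.0 - y) * (1.0 - p).ln())
}

fn main() {
    let preds = [0.9f32, 0.7, 0.2, 0.6];
    let labels = [1.0f32, 1.0, 0.0, 1.0];
    let losses: Vec<f32> = preds.iter().zip(&labels).map(|(&p, &y)| bce(p, y)).collect();
    let sum: f32 = losses.iter().sum();
    let mean = sum / losses.len() as f32;
    // Under Reduction::Sum the loss (and its gradients) grow with batch size,
    // whereas Reduction::Mean keeps them batch-size independent.
    println!("sum = {sum:.4}, mean = {mean:.4}");
}
```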