diff --git a/Cargo.lock b/Cargo.lock index 8ba5c2d..263317a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,7 +1077,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "1.2.2" +version = "1.2.3" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 393f6e2..d193930 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "1.2.2" +version = "1.2.3" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index 389ae68..52cbbf9 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -263,12 +263,41 @@ where #[cfg(test)] mod tests { + use burn::backend::{ndarray::NdArrayDevice, NdArray}; + use super::*; - use crate::{convertor_tests::anki21_sample_file_converted_to_fsrs, FSRSItem, FSRSReview}; + use crate::{ + convertor_tests::anki21_sample_file_converted_to_fsrs, + dataset::{prepare_training_data, FSRSBatcher, FSRSDataset}, + FSRSItem, FSRSReview, + }; + + #[test] + fn batch_shuffle_dataloader() { + let train_set = anki21_sample_file_converted_to_fsrs(); + let (_pre_train_set, train_set) = prepare_training_data(train_set); + let dataset = FSRSDataset::from(train_set); + let batch_size = 512; + let seed = 42; + let device = NdArrayDevice::Cpu; + type Backend = NdArray; + let batcher = FSRSBatcher::::new(device); + let dataloader = + BatchShuffledDataLoaderBuilder::new(batcher).build(dataset, batch_size, seed); + let item = dataloader.iter().next().unwrap(); + assert_eq!( + item.t_historys.shape(), + burn::tensor::Shape { dims: [6, 512] } + ); + let item2 = dataloader.iter().next().unwrap(); + assert_eq!( + item2.t_historys.shape(), + burn::tensor::Shape { dims: [4, 512] } + ); + } #[test] fn batch_shuffle() { - use crate::dataset::FSRSDataset; let dataset = Arc::new(FSRSDataset::from(anki21_sample_file_converted_to_fsrs())); let batch_size = 10; let seed = 42; @@ -484,7 +513,6 @@ mod tests { #[test] fn item_shuffle() { - use crate::dataset::FSRSDataset; use burn::data::dataset::transform::ShuffledDataset; let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs()); let seed = 42; diff --git a/src/dataset.rs b/src/dataset.rs index 41d863f..08c49d3 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -6,6 +6,7 @@ use burn::{ tensor::{backend::Backend, Data, ElementConversion, Float, Int, Shape, Tensor}, }; +use itertools::Itertools; use serde::{Deserialize, Serialize}; /// Stores a list of reviews for a card, in chronological order. Each FSRSItem corresponds @@ -190,7 +191,7 @@ pub fn filter_outlier( let mut filtered_items = vec![]; let mut removed_pairs: [HashSet<_>; 5] = Default::default(); - for (rating, delta_t_groups) in groups.into_iter() { + for (rating, delta_t_groups) in groups.into_iter().sorted_by_key(|&(k, _)| k) { let mut sub_groups = delta_t_groups.into_iter().collect::>(); // order by size of sub group ascending and delta_t descending