Skip to content

Commit

Permalink
Fix/filter_outlier outputs dataset in arbitrary order (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Sep 28, 2024
1 parent bc5d602 commit 324671d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "1.2.2"
version = "1.2.3"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
34 changes: 31 additions & 3 deletions src/batch_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,41 @@ where

#[cfg(test)]
mod tests {
use burn::backend::{ndarray::NdArrayDevice, NdArray};

use super::*;
use crate::{convertor_tests::anki21_sample_file_converted_to_fsrs, FSRSItem, FSRSReview};
use crate::{
convertor_tests::anki21_sample_file_converted_to_fsrs,
dataset::{prepare_training_data, FSRSBatcher, FSRSDataset},
FSRSItem, FSRSReview,
};

/// Builds a `BatchShuffledDataLoaderBuilder` loader over the training split of the
/// sample data (batch size 512, seed 42, `NdArray<f32>` on the CPU device) and
/// checks the `t_historys` shape of the first batch yielded by two successive
/// `iter()` calls.
#[test]
fn batch_shuffle_dataloader() {
    type Backend = NdArray<f32>;

    // Only the second element returned by `prepare_training_data` is used here.
    let (_, items) = prepare_training_data(anki21_sample_file_converted_to_fsrs());
    let dataset = FSRSDataset::from(items);
    let (batch_size, seed) = (512, 42);
    let batcher = FSRSBatcher::<Backend>::new(NdArrayDevice::Cpu);
    let loader = BatchShuffledDataLoaderBuilder::new(batcher).build(dataset, batch_size, seed);

    // First batch of the first iterator.
    let first_batch = loader.iter().next().unwrap();
    let expected_first = burn::tensor::Shape { dims: [6, 512] };
    assert_eq!(first_batch.t_historys.shape(), expected_first);

    // First batch of a second, fresh iterator. The expected shape differs from the
    // one above; NOTE(review): presumably each `iter()` call draws a new batch
    // order from the seeded RNG — confirm against the loader implementation.
    let second_batch = loader.iter().next().unwrap();
    let expected_second = burn::tensor::Shape { dims: [4, 512] };
    assert_eq!(second_batch.t_historys.shape(), expected_second);
}

#[test]
fn batch_shuffle() {
use crate::dataset::FSRSDataset;
let dataset = Arc::new(FSRSDataset::from(anki21_sample_file_converted_to_fsrs()));
let batch_size = 10;
let seed = 42;
Expand Down Expand Up @@ -484,7 +513,6 @@ mod tests {

#[test]
fn item_shuffle() {
use crate::dataset::FSRSDataset;
use burn::data::dataset::transform::ShuffledDataset;
let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs());
let seed = 42;
Expand Down
3 changes: 2 additions & 1 deletion src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use burn::{
tensor::{backend::Backend, Data, ElementConversion, Float, Int, Shape, Tensor},
};

use itertools::Itertools;
use serde::{Deserialize, Serialize};

/// Stores a list of reviews for a card, in chronological order. Each FSRSItem corresponds
Expand Down Expand Up @@ -190,7 +191,7 @@ pub fn filter_outlier(
let mut filtered_items = vec![];
let mut removed_pairs: [HashSet<_>; 5] = Default::default();

for (rating, delta_t_groups) in groups.into_iter() {
for (rating, delta_t_groups) in groups.into_iter().sorted_by_key(|&(k, _)| k) {
let mut sub_groups = delta_t_groups.into_iter().collect::<Vec<_>>();

// order by size of sub group ascending and delta_t descending
Expand Down

0 comments on commit 324671d

Please sign in to comment.