Skip to content

Commit

Permalink
Add a test to check tensor conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
dae committed Aug 21, 2023
1 parent c35e770 commit 432066f
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ pub fn anki_to_fsrs() -> Vec<FSRSItem> {
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::FSRSBatcher;
use burn::data::dataloader::batcher::Batcher;
use burn::tensor::Data;

// This test currently expects the following .anki21 file to be placed in tests/data/:
// https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip
Expand Down Expand Up @@ -242,7 +245,7 @@ mod tests {
);

// convert a subset and check it matches expectations
let fsrs_items = single_card_revlog
let mut fsrs_items = single_card_revlog
.into_iter()
.filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
.flatten()
Expand Down Expand Up @@ -340,5 +343,22 @@ mod tests {
}
]
);

use burn_ndarray::NdArrayDevice;
let device = NdArrayDevice::Cpu;
use burn_ndarray::NdArrayBackend;
type Backend = NdArrayBackend<f32>;
let batcher = FSRSBatcher::<Backend>::new(device);
let res = batcher.batch(vec![fsrs_items.pop().unwrap()]);
assert_eq!(res.delta_ts.into_scalar(), 64.0);
assert_eq!(
res.r_historys.squeeze(1).to_data(),
Data::from([3.0, 3.0, 3.0, 3.0, 2.0])
);
assert_eq!(
res.t_historys.squeeze(1).to_data(),
Data::from([0.0, 5.0, 10.0, 22.0, 56.0])
);
assert_eq!(res.labels.to_data(), Data::from([1]));
}
}

0 comments on commit 432066f

Please sign in to comment.