Skip to content

Commit

Permalink
add assert_eq for test_batcher
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Aug 24, 2023
1 parent f868215 commit be65fa2
Showing 1 changed file with 153 additions and 10 deletions.
163 changes: 153 additions & 10 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,158 @@ fn test_batcher() {
type Backend = NdArrayBackend<f32>;
let device = NdArrayDevice::Cpu;
let batcher: FSRSBatcher<Backend> = FSRSBatcher::<Backend>::new(device);
let dataset = FSRSDataset::train();
let mut items = vec![];
for item in dataset.iter() {
items.push(item);
if items.len() >= 8 {
break;
}
}
dbg!(&items);
let items = vec![
FSRSItem {
reviews: vec![
FSRSReview {
rating: 4,
delta_t: 0,
},
FSRSReview {
rating: 3,
delta_t: 5,
},
],
},
FSRSItem {
reviews: vec![
FSRSReview {
rating: 4,
delta_t: 0,
},
FSRSReview {
rating: 3,
delta_t: 5,
},
FSRSReview {
rating: 3,
delta_t: 11,
},
],
},
FSRSItem {
reviews: vec![
FSRSReview {
rating: 4,
delta_t: 0,
},
FSRSReview {
rating: 3,
delta_t: 2,
},
],
},
FSRSItem {
reviews: vec![
FSRSReview {
rating: 4,
delta_t: 0,
},
FSRSReview {
rating: 3,
delta_t: 2,
},
FSRSReview {
rating: 3,
delta_t: 6,
},
],
},
FSRSItem {
reviews: vec![
FSRSReview {
rating: 4,
delta_t: 0,
},
FSRSReview {
rating: 3,
delta_t: 2,
},
FSRSReview {
rating: 3,
delta_t: 6,
},
FSRSReview {
rating: 3,
delta_t: 16,
},
],
},
FSRSItem {
reviews: vec![
FSRSReview {
rating: 4,
delta_t: 0,
},
FSRSReview {
rating: 3,
delta_t: 2,
},
FSRSReview {
rating: 3,
delta_t: 6,
},
FSRSReview {
rating: 3,
delta_t: 16,
},
FSRSReview {
rating: 3,
delta_t: 39,
},
],
},
FSRSItem {
reviews: vec![
FSRSReview {
rating: 1,
delta_t: 0,
},
FSRSReview {
rating: 1,
delta_t: 1,
},
],
},
FSRSItem {
reviews: vec![
FSRSReview {
rating: 1,
delta_t: 0,
},
FSRSReview {
rating: 1,
delta_t: 1,
},
FSRSReview {
rating: 3,
delta_t: 1,
},
],
},
];
let batch = batcher.batch(items);
dbg!(&batch);
assert_eq!(
batch.t_historys.to_data(),
Data::from([
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 5.0, 0.0, 2.0, 2.0, 2.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0, 0.0]
])
);
assert_eq!(
batch.r_historys.to_data(),
Data::from([
[4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0],
[0.0, 3.0, 0.0, 3.0, 3.0, 3.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0]
])
);
assert_eq!(
batch.delta_ts.to_data(),
Data::from([5.0, 11.0, 2.0, 6.0, 16.0, 39.0, 1.0, 1.0])
);
assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
}

0 comments on commit be65fa2

Please sign in to comment.