Skip to content

Commit

Permalink
Convert Vec<Revlog> directly to FSRSItem
Browse files Browse the repository at this point in the history
  • Loading branch information
dae committed Aug 21, 2023
1 parent c608306 commit 4ec62fe
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 72 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ log = "0.4"
rusqlite = { version = "0.29.0" }
chrono = "0.4.26"
chrono-tz = "0.8.3"
itertools = "0.11.0"
111 changes: 40 additions & 71 deletions src/convertor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use chrono::prelude::*;
use chrono_tz::Tz;
use itertools::Itertools;
use rusqlite::{Connection, Result, Row};
use std::collections::HashMap;

Expand All @@ -12,9 +13,6 @@ struct RevlogEntry {
button_chosen: i32,
review_kind: i64,
delta_t: i32,
i: usize,
r_history: Vec<i32>,
t_history: Vec<i32>,
}

fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
Expand All @@ -24,9 +22,6 @@ fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
button_chosen: row.get(2)?,
review_kind: row.get(3).unwrap_or_default(),
delta_t: 0,
i: 0,
r_history: vec![],
t_history: vec![],
})
}

Expand Down Expand Up @@ -104,11 +99,11 @@ fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chr
datetime.date_naive()
}

fn extract_time_series_feature(
fn convert_to_fsrs_item(
mut entries: Vec<RevlogEntry>,
next_day_starts_at: i64,
timezone: Tz,
) -> Option<Vec<RevlogEntry>> {
) -> Option<FSRSItem> {
// Find the index of the first RevlogEntry in the last continuous group where review_kind = 0
// 寻找最后一组连续 review_kind = 0 的第一个 RevlogEntry 的索引
let mut index_to_keep = 0;
Expand All @@ -129,6 +124,15 @@ fn extract_time_series_feature(
// 删除此 RevlogEntry 之前的所有条目
entries.drain(..index_to_keep);

// we ignore cards that don't start in the learning state
if let Some(entry) = entries.first() {
if entry.review_kind != 0 {
return None;
}
} else {
return None;
}

// Increment review_kind of all entries by 1
// 将所有 review_kind + 1
for entry in &mut entries {
Expand All @@ -151,21 +155,6 @@ fn extract_time_series_feature(
entries[i].delta_t = (date_current - date_previous).num_days() as i32;
}

// Compute i, r_history, t_history
// 计算 i, r_history, t_history
for i in 0..entries.len() {
// Position starts from 1
// 位置从 1 开始
entries[i].i = i + 1;

// Except for the first entry, the remaining entries add the preceding button_chosen and delta_t to r_history and t_history
// 除了第一个条目,其余条目将前面的 button_chosen 和 delta_t 加入 r_history 和 t_history
if i > 0 {
entries[i].r_history = entries[0..i].iter().map(|e| e.button_chosen).collect();
entries[i].t_history = entries[0..i].iter().map(|e| e.delta_t).collect();
}
}

// Find the RevlogEntry with review_kind = 0 where the preceding RevlogEntry has review_kind of 1 or 2, then remove it and all following RevlogEntries
// 找到 review_kind = 0 且前一个 RevlogEntry 的 review_kind 是 1 或 2 的 RevlogEntry,然后删除其及其之后的所有 RevlogEntry
if let Some(index_to_remove) = entries.windows(2).enumerate().find_map(|(i, window)| {
Expand All @@ -183,52 +172,37 @@ fn extract_time_series_feature(
entries.truncate(index_to_remove);
}

// we ignore cards that don't start in the learning state
if let Some(first) = entries.first() {
if first.review_kind == 1 {
return Some(entries)
}
}
None
}

fn convert_to_fsrs_items(revlogs: Vec<Vec<RevlogEntry>>) -> Vec<FSRSItem> {
revlogs
.into_iter()
.flat_map(|group| {
group
.into_iter()
.filter(|entry| entry.i != 1) // 过滤掉 i = 1 的 RevlogEntry
.map(|entry| FSRSItem {
reviews: entry
.r_history
.iter()
.zip(entry.t_history.iter())
.map(|(&r, &t)| Review {
rating: r,
delta_t: t,
})
.collect(),
delta_t: entry.delta_t as f32,
label: match entry.button_chosen {
1 => 0.0,
2 | 3 | 4 => 1.0,
_ => panic!("Unexpected value for button_chosen"),
},
})
// Build the review history: every entry after the first becomes a Review
// whose rating is that entry's button_chosen and whose delta_t is that entry's delta_t.
let reviews = entries
.iter()
.skip(1)
.map(|entry| Review {
rating: entry.button_chosen,
delta_t: entry.delta_t,
})
.collect()
.collect_vec();

let last = entries.last().unwrap();
Some(FSRSItem {
reviews,
delta_t: last.delta_t as f32,
label: match last.button_chosen {
1 => 0.0,
2 | 3 | 4 | _ => 1.0,
},
})
}

pub fn anki_to_fsrs() -> Vec<FSRSItem> {
let revlogs = read_collection();
let revlogs_per_card = group_by_cid(revlogs);
let extracted_revlogs_per_card: Vec<Vec<RevlogEntry>> = revlogs_per_card
revlogs_per_card
.into_iter()
.filter_map(|entries| extract_time_series_feature(entries, 4, Tz::Asia__Shanghai))
.collect();

convert_to_fsrs_items(extracted_revlogs_per_card)
.filter_map(|entries| convert_to_fsrs_item(entries, 4, Tz::Asia__Shanghai))
.collect()
}

#[test]
Expand All @@ -237,15 +211,10 @@ fn test() {
dbg!(revlogs.len());
let revlogs_per_card = group_by_cid(revlogs);
dbg!(revlogs_per_card.len());
let mut extracted_revlogs_per_card: Vec<Vec<RevlogEntry>> = revlogs_per_card
let extracted_revlogs_per_card = revlogs_per_card
.into_iter()
.map(|entries| extract_time_series_feature(entries, 4, Tz::Asia__Shanghai))
.collect();
.flat_map(|entries| convert_to_fsrs_item(entries, 4, Tz::Asia__Shanghai))
.collect_vec();

dbg!(extracted_revlogs_per_card
.iter()
.map(|x| x.len())
.sum::<usize>());
let fsrs_items: Vec<FSRSItem> = convert_to_fsrs_items(extracted_revlogs_per_card);
dbg!(fsrs_items.len());
dbg!(extracted_revlogs_per_card);
}
4 changes: 3 additions & 1 deletion src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ fn test_from_anki() {
use burn::data::dataloader::Dataset;
use burn::data::dataset::InMemDataset;

let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
let items = anki_to_fsrs();
dbg!(&items.len());
let dataset = InMemDataset::<FSRSItem>::new(items);
let item = dataset.get(704).unwrap();
dbg!(&item);

Expand Down

0 comments on commit 4ec62fe

Please sign in to comment.