diff --git a/Cargo.toml b/Cargo.toml
index 38c0964..b1d1930 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,3 +14,4 @@ log = "0.4"
 rusqlite = { version = "0.29.0" }
 chrono = "0.4.26"
 chrono-tz = "0.8.3"
+itertools = "0.11.0"
\ No newline at end of file
diff --git a/src/convertor.rs b/src/convertor.rs
index 9641b2a..39eef68 100644
--- a/src/convertor.rs
+++ b/src/convertor.rs
@@ -1,5 +1,6 @@
 use chrono::prelude::*;
 use chrono_tz::Tz;
+use itertools::Itertools;
 use rusqlite::{Connection, Result, Row};
 use std::collections::HashMap;
 
@@ -12,9 +13,6 @@ struct RevlogEntry {
     button_chosen: i32,
     review_kind: i64,
     delta_t: i32,
-    i: usize,
-    r_history: Vec<i32>,
-    t_history: Vec<i32>,
 }
 
 fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
@@ -24,9 +22,6 @@ fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
         button_chosen: row.get(2)?,
         review_kind: row.get(3).unwrap_or_default(),
         delta_t: 0,
-        i: 0,
-        r_history: vec![],
-        t_history: vec![],
     })
 }
 
@@ -104,11 +99,11 @@ fn convert_to_date(timestamp: i64, next_day_starts_at: i64, timezone: Tz) -> chr
     datetime.date_naive()
 }
 
-fn extract_time_series_feature(
+fn convert_to_fsrs_item(
     mut entries: Vec<RevlogEntry>,
     next_day_starts_at: i64,
     timezone: Tz,
-) -> Option<Vec<RevlogEntry>> {
+) -> Option<FSRSItem> {
     // Find the index of the first RevlogEntry in the last continuous group where review_kind = 0
     // 寻找最后一组连续 review_kind = 0 的第一个 RevlogEntry 的索引
     let mut index_to_keep = 0;
@@ -129,6 +124,15 @@ fn extract_time_series_feature(
     // 删除此 RevlogEntry 之前的所有条目
     entries.drain(..index_to_keep);
 
+    // we ignore cards that don't start in the learning state
+    if let Some(entry) = entries.first() {
+        if entry.review_kind != 0 {
+            return None;
+        }
+    } else {
+        return None;
+    }
+
     // Increment review_kind of all entries by 1
     // 将所有 review_kind + 1
     for entry in &mut entries {
@@ -151,21 +155,6 @@ fn extract_time_series_feature(
         entries[i].delta_t = (date_current - date_previous).num_days() as i32;
     }
 
-    // Compute i, r_history, t_history
-    // 计算 i, r_history, t_history
-    for i in 0..entries.len() {
-        // Position starts from 1
-        // 位置从 1 开始
-        entries[i].i = i + 1;
-
-        // Except for the first entry, the remaining entries add the preceding button_chosen and delta_t to r_history and t_history
-        // 除了第一个条目,其余条目将前面的 button_chosen 和 delta_t 加入 r_history 和 t_history
-        if i > 0 {
-            entries[i].r_history = entries[0..i].iter().map(|e| e.button_chosen).collect();
-            entries[i].t_history = entries[0..i].iter().map(|e| e.delta_t).collect();
-        }
-    }
-
     // Find the RevlogEntry with review_kind = 0 where the preceding RevlogEntry has review_kind of 1 or 2, then remove it and all following RevlogEntries
     // 找到 review_kind = 0 且前一个 RevlogEntry 的 review_kind 是 1 或 2 的 RevlogEntry,然后删除其及其之后的所有 RevlogEntry
     if let Some(index_to_remove) = entries.windows(2).enumerate().find_map(|(i, window)| {
@@ -183,52 +172,37 @@ fn extract_time_series_feature(
         entries.truncate(index_to_remove);
     }
 
-    // we ignore cards that don't start in the learning state
-    if let Some(first) = entries.first() {
-        if first.review_kind == 1 {
-            return Some(entries)
-        }
-    }
-    None
-}
-
-fn convert_to_fsrs_items(revlogs: Vec<Vec<RevlogEntry>>) -> Vec<FSRSItem> {
-    revlogs
-        .into_iter()
-        .flat_map(|group| {
-            group
-                .into_iter()
-                .filter(|entry| entry.i != 1) // 过滤掉 i = 1 的 RevlogEntry
-                .map(|entry| FSRSItem {
-                    reviews: entry
-                        .r_history
-                        .iter()
-                        .zip(entry.t_history.iter())
-                        .map(|(&r, &t)| Review {
-                            rating: r,
-                            delta_t: t,
-                        })
-                        .collect(),
-                    delta_t: entry.delta_t as f32,
-                    label: match entry.button_chosen {
-                        1 => 0.0,
-                        2 | 3 | 4 => 1.0,
-                        _ => panic!("Unexpected value for button_chosen"),
-                    },
-                })
+    // Compute i, r_history, t_history
+    // 计算 i, r_history, t_history
+    // Except for the first entry, the remaining entries add the preceding button_chosen and delta_t to r_history and t_history
+    // 除了第一个条目,其余条目将前面的 button_chosen 和 delta_t 加入 r_history 和 t_history
+    let reviews = entries
+        .iter()
+        .skip(1)
+        .map(|entry| Review {
+            rating: entry.button_chosen,
+            delta_t: entry.delta_t,
         })
-        .collect()
+        .collect_vec();
+
+    let last = entries.last().unwrap();
+    Some(FSRSItem {
+        reviews,
+        delta_t: last.delta_t as f32,
+        label: match last.button_chosen {
+            1 => 0.0,
+            2 | 3 | 4 | _ => 1.0,
+        },
+    })
 }
 
 pub fn anki_to_fsrs() -> Vec<FSRSItem> {
     let revlogs = read_collection();
     let revlogs_per_card = group_by_cid(revlogs);
-    let extracted_revlogs_per_card: Vec<Vec<RevlogEntry>> = revlogs_per_card
+    revlogs_per_card
         .into_iter()
-        .filter_map(|entries| extract_time_series_feature(entries, 4, Tz::Asia__Shanghai))
-        .collect();
-
-    convert_to_fsrs_items(extracted_revlogs_per_card)
+        .filter_map(|entries| convert_to_fsrs_item(entries, 4, Tz::Asia__Shanghai))
+        .collect()
 }
 
 #[test]
@@ -237,15 +211,10 @@ fn test() {
     dbg!(revlogs.len());
     let revlogs_per_card = group_by_cid(revlogs);
     dbg!(revlogs_per_card.len());
-    let mut extracted_revlogs_per_card: Vec<Vec<RevlogEntry>> = revlogs_per_card
+    let extracted_revlogs_per_card = revlogs_per_card
         .into_iter()
-        .map(|entries| extract_time_series_feature(entries, 4, Tz::Asia__Shanghai))
-        .collect();
+        .flat_map(|entries| convert_to_fsrs_item(entries, 4, Tz::Asia__Shanghai))
+        .collect_vec();
 
-    dbg!(extracted_revlogs_per_card
-        .iter()
-        .map(|x| x.len())
-        .sum::<usize>());
-    let fsrs_items: Vec<FSRSItem> = convert_to_fsrs_items(extracted_revlogs_per_card);
-    dbg!(fsrs_items.len());
+    dbg!(extracted_revlogs_per_card);
 }
diff --git a/src/dataset.rs b/src/dataset.rs
index 9fd1793..69933dd 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -159,7 +159,9 @@ fn test_from_anki() {
     use burn::data::dataloader::Dataset;
     use burn::data::dataset::InMemDataset;
 
-    let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
+    let items = anki_to_fsrs();
+    dbg!(&items.len());
+    let dataset = InMemDataset::<FSRSItem>::new(items);
     let item = dataset.get(704).unwrap();
     dbg!(&item);
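A self-contained sketch of what the new per-card conversion produces may help reviewers. The `RevlogEntry`, `Review`, and `FSRSItem` definitions below are simplified stand-ins that keep only the fields this diff touches (the real structs carry more columns), and `to_fsrs_item` is a hypothetical helper mirroring just the tail of the new `convert_to_fsrs_item`: every entry after the first becomes a `Review`, and the last entry supplies the prediction target.

```rust
// Simplified stand-ins for the real types in src/convertor.rs / src/dataset.rs.
#[derive(Debug, Clone)]
struct RevlogEntry {
    button_chosen: i32, // 1 = Again, 2 = Hard, 3 = Good, 4 = Easy
    delta_t: i32,       // days elapsed since the previous review
}

#[derive(Debug)]
struct Review {
    rating: i32,
    delta_t: i32,
}

#[derive(Debug)]
struct FSRSItem {
    reviews: Vec<Review>,
    delta_t: f32,
    label: f32,
}

// Every entry except the first becomes part of the review history; the last
// entry supplies the target: its delta_t plus a binary label
// (Again -> 0.0, anything else -> 1.0).
fn to_fsrs_item(entries: &[RevlogEntry]) -> Option<FSRSItem> {
    let last = entries.last()?;
    let reviews = entries
        .iter()
        .skip(1)
        .map(|e| Review {
            rating: e.button_chosen,
            delta_t: e.delta_t,
        })
        .collect();
    Some(FSRSItem {
        reviews,
        delta_t: last.delta_t as f32,
        label: if last.button_chosen == 1 { 0.0 } else { 1.0 },
    })
}

fn main() {
    // Toy card: learned (Good), reviewed after 3 days (Good), forgotten after 7 more days (Again).
    let card = vec![
        RevlogEntry { button_chosen: 3, delta_t: 0 },
        RevlogEntry { button_chosen: 3, delta_t: 3 },
        RevlogEntry { button_chosen: 1, delta_t: 7 },
    ];
    // Expected: reviews = [(3, 3), (1, 7)], delta_t = 7.0, label = 0.0.
    dbg!(to_fsrs_item(&card));
}
```

Net effect of the diff: instead of expanding each card into one `FSRSItem` per review (the removed `i`/`r_history`/`t_history` bookkeeping plus `convert_to_fsrs_items`), the converter now emits a single `FSRSItem` per card, using the full preceding history as `reviews` and the last review as the `delta_t`/`label` target.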