Skip to content

Commit

Permalink
add explanation for the sql query (#23)
Browse files Browse the repository at this point in the history
* add explanation for sql query.
  • Loading branch information
asukaminato0721 authored Aug 24, 2023
1 parent 75c1622 commit 7d80f4e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 39 deletions.
60 changes: 26 additions & 34 deletions src/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::collections::HashMap;

use crate::dataset::{FSRSItem, FSRSReview};

#[derive(Debug, Clone)]
#[derive(Clone, Debug, Default)]
struct RevlogEntry {
id: i64,
cid: i64,
Expand All @@ -24,8 +24,8 @@ fn row_to_revlog_entry(row: &Row) -> Result<RevlogEntry> {
})
}

fn read_collection() -> Vec<RevlogEntry> {
let db = Connection::open("tests/data/collection.anki21").unwrap();
fn read_collection() -> Result<Vec<RevlogEntry>> {
let db = Connection::open("tests/data/collection.anki21")?;
let filter_out_suspended_cards = false;
let filter_out_flags = vec![];
let flags_str = if !filter_out_flags.is_empty() {
Expand All @@ -34,7 +34,7 @@ fn read_collection() -> Vec<RevlogEntry> {
filter_out_flags
.iter()
.map(|x: &i32| x.to_string())
.collect::<Vec<String>>()
.collect::<Vec<_>>()
.join(", ")
)
} else {
Expand All @@ -48,32 +48,26 @@ fn read_collection() -> Vec<RevlogEntry> {
};

let current_timestamp = Utc::now().timestamp() * 1000;

let query = format!(
"SELECT id, cid, ease, type
FROM revlog
WHERE (type != 4 OR ivl <= 0)
AND (factor != 0 or type != 3)
AND id < {}
AND cid < {}
AND cid IN (
SELECT id
FROM cards
WHERE queue != 0
{}
{}
)",
current_timestamp, current_timestamp, suspended_cards_str, flags_str
);

// This sql query will be remove in the futrue. See https://github.com/open-spaced-repetition/fsrs-optimizer-burn/pull/14#issuecomment-1685895643
let revlogs = db
.prepare_cached(&query)
.unwrap()
.query_and_then([], row_to_revlog_entry)
.unwrap()
.collect::<Result<Vec<RevlogEntry>>>()
.unwrap();
revlogs
.prepare_cached(&format!(
"SELECT id, cid, ease, type
FROM revlog
WHERE (type != 4 OR ivl <= 0)
AND (factor != 0 or type != 3)
AND id < ?1
AND cid < ?2
AND cid IN (
SELECT id
FROM cards
WHERE queue != 0
{suspended_cards_str}
{flags_str}
)"
))?
.query_and_then((current_timestamp, current_timestamp), row_to_revlog_entry)?
.collect::<Result<Vec<_>>>()?;
Ok(revlogs)
}

fn group_by_cid(revlogs: Vec<RevlogEntry>) -> Vec<Vec<RevlogEntry>> {
Expand Down Expand Up @@ -136,9 +130,7 @@ fn convert_to_fsrs_items(

// Increment review_kind of all entries by 1
// 将所有 review_kind + 1
for entry in &mut entries {
entry.review_kind += 1;
}
entries.iter_mut().for_each(|entry| entry.review_kind += 1);

// Convert the timestamp and keep the first RevlogEntry for each date
// 转换时间戳并保留每个日期的第一个 RevlogEntry
Expand Down Expand Up @@ -198,7 +190,7 @@ fn convert_to_fsrs_items(
}

pub fn anki_to_fsrs() -> Vec<FSRSItem> {
let revlogs = read_collection();
let revlogs = read_collection().expect("read error");
let revlogs_per_card = group_by_cid(revlogs);
revlogs_per_card
.into_iter()
Expand All @@ -219,7 +211,7 @@ mod tests {
// https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip
#[test]
fn test() {
let revlogs = read_collection();
let revlogs = read_collection().unwrap();
let single_card_revlog = vec![revlogs
.iter()
.filter(|r| r.cid == 1528947214762)
Expand Down
4 changes: 2 additions & 2 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ fn test_forward() {
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
]);
let (stability, difficulty) = model.forward(delta_ts, ratings);
println!("stability {:?}", stability);
println!("difficulty {:?}", difficulty);
dbg!(&stability);
dbg!(&difficulty);
}

#[cfg(test)]
Expand Down
7 changes: 4 additions & 3 deletions src/weight_clipper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights: Tensor<B, 1>) -> Ten

let val: &mut Vec<f32> = &mut weights.to_data().value;

for (i, w) in val.iter_mut().skip(4).enumerate() {
*w = w.clamp(CLAMPS[i].0, CLAMPS[i].1);
}
val.iter_mut()
.skip(4)
.zip(CLAMPS)
.for_each(|(w, (low, high))| *w = w.clamp(low, high));

Tensor::from_data(Data::new(val.clone(), weights.shape()))
}
Expand Down

0 comments on commit 7d80f4e

Please sign in to comment.