diff --git a/.github/workflows/auto-merge.yml b/.github/workflows/auto-merge.yml index ebc3879..8d53c68 100644 --- a/.github/workflows/auto-merge.yml +++ b/.github/workflows/auto-merge.yml @@ -36,5 +36,5 @@ jobs: git fetch --unshallow git checkout fsrs-browser git pull - git merge main -m "Auto-merge for $TAG" + git merge origin/main -m "Auto-merge for $TAG" git push diff --git a/Cargo.lock b/Cargo.lock index 31ee398..161ac3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.11" @@ -62,6 +73,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" +[[package]] +name = "anyhow" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" + [[package]] name = "arrayvec" version = "0.7.4" @@ -94,6 +111,18 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +[[package]] +name = "base64" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "2.0.0-rc.3" @@ -151,6 +180,8 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e041d5f4eef703500763e599050cba419cd90d464172d71e3d5397baebbf1d8a" dependencies = [ "burn-core", "burn-train", @@ -159,6 +190,8 @@ dependencies = [ [[package]] name = "burn-autodiff" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e23c815bc728ac60343b8820fb71e9b4a2c0cb283bfd58828246caacabe6eff" dependencies = [ "burn-common", "burn-tensor", @@ -170,6 +203,8 @@ dependencies = [ [[package]] name = "burn-candle" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d319a88254df7e9740154c32e862d721d29e5f782c0fdf7004f6b9ed5c8369f" dependencies = [ "burn-tensor", "candle-core", @@ -180,6 +215,8 @@ dependencies = [ [[package]] name = "burn-common" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a14cddb7f93dc985637e21f068a343acdfc4d62232fb11101f88c2739abad249" dependencies = [ "async-trait", "derive-new", @@ -199,6 +236,8 @@ dependencies = [ [[package]] name = "burn-compute" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbe641bbe653d04fb070a80946f3db13485e04d7d12104aab9287a1d55b3493c" dependencies = [ "burn-common", "derive-new", @@ -215,6 +254,8 @@ dependencies = [ [[package]] name = "burn-core" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f3532e2f722bca39aefa69aea2b8e6cf2c3bf70f95ba8421b557082d89ea476" dependencies = [ "bincode", "burn-autodiff", @@ -241,6 +282,8 @@ dependencies = [ [[package]] name = "burn-dataset" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebb03147d7c50f31c673ee7f672543caddd56bc5de906810db23e396ca062054" dependencies = [ "csv", "derive-new", @@ -265,6 +308,8 @@ dependencies = [ [[package]] name = "burn-derive" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dbf7e7f4154821f1a74c709ed2191304701e6f56b6221aec8585b8a16d16ae5" dependencies = [ "derive-new", "proc-macro2", @@ -272,9 +317,26 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "burn-fusion" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "934015329ca3b41a6a6bc7b6a4eedcda04d899085e0b3273e7fb330358c15cf8" +dependencies = [ + "burn-common", + "burn-tensor", + "derive-new", + "hashbrown 0.14.3", + "log", + "serde", + "spin", +] + [[package]] name = "burn-jit" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d257cec36c1b4404c79355492a0c32d0775ed5d7826241051323eb88f1e633dc" dependencies = [ "burn-common", "burn-compute", @@ -293,6 +355,8 @@ dependencies = [ [[package]] name = "burn-ndarray" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f3a7d13e0116b4e442bda45aa9eb8a4cc3b70cf7d67197b13d539753275428c" dependencies = [ "burn-autodiff", "burn-common", @@ -307,9 +371,24 @@ dependencies = [ "spin", ] +[[package]] +name = "burn-tch" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee78099b81128ba1122c645344cb7126c1fadfc05b284150efd94731001f0a7" +dependencies = [ + "burn-tensor", + "half", + "libc", + "rand", + "tch", +] + [[package]] name = "burn-tensor" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9395b25136b8fff2ca293dc30e8ca915cc811ed48ffbb147063b6c9c7fcba6a" dependencies = [ "burn-common", "derive-new", @@ -324,6 +403,8 @@ dependencies = [ [[package]] name = "burn-train" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da95f83ed597cdb313fb0e18b389f88b96d5bcd1a37620adc969fe2934d486ff" dependencies = [ "burn-common", "burn-core", @@ -338,6 +419,8 @@ dependencies = [ [[package]] name = "burn-wgpu" version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6377670147b65387c807938b4f77a0b149b154ecc8b749f66ad068d345efac14" dependencies = [ "burn-common", "burn-compute", @@ -393,7 +476,7 @@ dependencies = [ "rand", "rand_distr", "rayon", - "safetensors", + "safetensors 0.4.3", "thiserror", "yoke", "zip", @@ -410,6 +493,11 @@ name = "cc" version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d32a725bc159af97c3e629873bb9f88fb8cf8a4867175f76dc987815ea07c83b" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] [[package]] name = "cfg-if" @@ -556,6 +644,28 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d375883580a668c7481ea6631fc1a8863e33cc335bf56bfad8d7e6d4b04b13a5" +dependencies = [ + "com_macros_support", + "proc-macro2", + "syn 1.0.109", +] + +[[package]] +name = "com_macros_support" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad899a1087a9296d5644792d7cb72b8e34c1bec8e7d4fbc002230169a6e8710c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -907,7 +1017,7 @@ checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" [[package]] name = "fsrs" -version = "0.6.1" +version = "0.6.2" dependencies = [ "burn", "chrono", @@ -925,7 +1035,6 @@ dependencies = [ "serde", "snafu", "strum 0.26.2", - "wasm-bindgen", ] [[package]] @@ -1299,6 +1408,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -1391,6 +1509,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + [[package]] name = "jpeg-decoder" version = "0.3.1" @@ -2098,6 +2225,21 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rmp" version = "0.8.14" @@ -2153,6 +2295,37 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" + +[[package]] +name = "rustls-webpki" +version = "0.102.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.15" @@ -2175,6 +2348,16 @@ dependencies = [ "serde_json", ] +[[package]] +name = "safetensors" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ced76b22c7fba1162f11a5a75d9d8405264b467a07ae0c9c29be119b9297db9" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -2414,6 +2597,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.60" @@ -2626,6 +2820,15 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-width" version = "0.1.11" @@ -2638,6 +2841,42 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +dependencies = [ + "base64", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "rustls-webpki", + "serde", + "serde_json", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "uuid" version = "1.8.0" @@ -2761,10 +3000,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] -name = "wasm_sync" -version = "0.1.2" +name = "web-sys" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff360cade7fec41ff0e9d2cda57fe58258c5f16def0e21302394659e6bbb0ea" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" dependencies = [ "js-sys", "wasm-bindgen", @@ -3181,6 +3420,35 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "zune-inflate" version = "0.2.54" diff --git a/Cargo.toml b/Cargo.toml index 9451b56..f81adcd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.6.1" +version = "0.6.2" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/convertor_tests.rs b/src/convertor_tests.rs index 060aeae..df1ce00 100644 --- a/src/convertor_tests.rs +++ b/src/convertor_tests.rs @@ -75,23 +75,19 @@ fn remove_revlog_before_last_first_learn(entries: Vec) -> Vec NaiveDate { - let timestamp_seconds = timestamp - next_day_starts_at * 3600 * 1000; - let datetime = Utc - .timestamp_millis_opt(timestamp_seconds) - .unwrap() - .with_timezone(&timezone); +fn convert_to_date(timestamp: i64, minute_offset: i32) -> NaiveDate { + let timestamp_seconds = timestamp + i64::from(minute_offset) * 60 * 1000; + let datetime = Utc.timestamp_millis_opt(timestamp_seconds).unwrap(); datetime.date_naive() } fn keep_first_revlog_same_date( mut entries: Vec, - next_day_starts_at: i64, - timezone: Tz, + minute_offset: i32, ) -> Vec { let mut unique_dates = std::collections::HashSet::new(); entries.retain(|entry| { - let date = convert_to_date(entry.id, next_day_starts_at, timezone); + let date = convert_to_date(entry.id, minute_offset); unique_dates.insert(date) }); entries @@ -102,17 +98,16 @@ fn keep_first_revlog_same_date( fn convert_to_fsrs_items( mut entries: Vec, - next_day_starts_at: i64, - timezone: Tz, + minute_offset: i32, ) -> Option> { // entries = filter_out_cram(entries); // entries = filter_out_manual(entries); entries = remove_revlog_before_last_first_learn(entries); - entries = keep_first_revlog_same_date(entries, next_day_starts_at, timezone); + entries = keep_first_revlog_same_date(entries, minute_offset); for i in 1..entries.len() { - let date_current = convert_to_date(entries[i].id, next_day_starts_at, timezone); - let date_previous = convert_to_date(entries[i - 1].id, next_day_starts_at, timezone); + let date_current = convert_to_date(entries[i].id, minute_offset); + let date_previous = convert_to_date(entries[i - 1].id, minute_offset); entries[i].last_interval = (date_current - date_previous).num_days() as i32; } @@ -178,20 +173,19 @@ impl Into for u8 { } /// Convert a series of revlog entries sorted by card id into FSRS items. -pub fn anki_to_fsrs(revlogs: Vec) -> Vec { +pub fn anki_to_fsrs(revlogs: Vec, minute_offset: i32) -> Vec { let mut revlogs = revlogs .into_iter() .group_by(|r| r.cid) .into_iter() - .filter_map(|(_cid, entries)| { - convert_to_fsrs_items(entries.collect(), 4, Tz::Asia__Shanghai) - }) + .filter_map(|(_cid, entries)| convert_to_fsrs_items(entries.collect(), minute_offset)) .flatten() .collect_vec(); revlogs.sort_by_cached_key(|r| r.reviews.len()); revlogs } +/* #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct RevlogCsv { // card_id,review_time,review_rating,review_state,review_duration @@ -202,7 +196,6 @@ pub struct RevlogCsv { pub review_duration: u32, } -/* pub(crate) fn data_from_csv() -> Vec { const CSV_FILE: &str = "tests/data/revlog.csv"; let rdr = csv::ReaderBuilder::new(); @@ -393,6 +386,7 @@ fn extract_simulator_config_from_revlog() { // This test currently expects the following .anki21 file to be placed in tests/data/: // https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip +/* #[test] fn conversion_works() { let revlogs = read_collection().unwrap(); @@ -556,7 +550,6 @@ fn ordering_of_inputs_should_not_change() { ); } -/* const NEXT_DAY_AT: i64 = 86400 * 100; fn revlog(review_kind: RevlogReviewKind, days_ago: i64) -> RevlogEntry { @@ -567,7 +560,6 @@ fn revlog(review_kind: RevlogReviewKind, days_ago: i64) -> RevlogEntry { ..Default::default() } } -*/ #[test] fn delta_t_is_correct() -> Result<()> { @@ -659,7 +651,7 @@ fn delta_t_is_correct() -> Result<()> { Ok(()) } -/* + #[test] fn test_filter_out_cram() { let revlog_vec = vec![ diff --git a/src/inference.rs b/src/inference.rs index 237a738..0b1787a 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -23,8 +23,8 @@ pub type Parameters = [f32]; use itertools::izip; pub static DEFAULT_PARAMETERS: [f32; 17] = [ - 0.5701, 1.4436, 4.1386, 10.9355, 5.1443, 1.2006, 0.8627, 0.0362, 1.629, 0.1342, 1.0166, 2.1174, - 0.0839, 0.3204, 1.4676, 0.219, 2.8237, + 0.4872, 1.4003, 3.7145, 13.8206, 5.1618, 1.2298, 0.8975, 0.031, 1.6474, 0.1367, 1.0461, 2.1072, + 0.0793, 0.3246, 1.587, 0.2272, 2.8755, ]; fn infer( @@ -481,7 +481,7 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.203_888, 0.029_732]), 5); + .assert_approx_eq(&Data::from([0.204_330, 0.031_510]), 5); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); @@ -494,7 +494,7 @@ mod tests { .unwrap(); Data::from([self_by_other, other_by_self]) - .assert_approx_eq(&Data::from([0.014_089, 0.016_483]), 5); + .assert_approx_eq(&Data::from([0.013_520, 0.019_003]), 5); Ok(()) } @@ -581,7 +581,7 @@ mod tests { fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(), MemoryState { stability: 9.999995, - difficulty: 7.255334 + difficulty: 7.4120417 } ); assert_eq!( diff --git a/src/model.rs b/src/model.rs index 632725b..1052b01 100644 --- a/src/model.rs +++ b/src/model.rs @@ -283,7 +283,14 @@ mod tests { let stability = model.init_stability(rating); assert_eq!( stability.to_data(), - Data::from([0.5701, 1.4436, 4.1386, 10.9355, 0.5701, 1.4436]) + Data::from([ + DEFAULT_PARAMETERS[0], + DEFAULT_PARAMETERS[1], + DEFAULT_PARAMETERS[2], + DEFAULT_PARAMETERS[3], + DEFAULT_PARAMETERS[0], + DEFAULT_PARAMETERS[1] + ]) ) } @@ -295,7 +302,14 @@ mod tests { let difficulty = model.init_difficulty(rating); assert_eq!( difficulty.to_data(), - Data::from([7.5455, 6.3449, 5.1443, 3.9436998, 7.5455, 6.3449]) + Data::from([ + DEFAULT_PARAMETERS[4] + 2.0 * DEFAULT_PARAMETERS[5], + DEFAULT_PARAMETERS[4] + DEFAULT_PARAMETERS[5], + DEFAULT_PARAMETERS[4], + DEFAULT_PARAMETERS[4] - DEFAULT_PARAMETERS[5], + DEFAULT_PARAMETERS[4] + 2.0 * DEFAULT_PARAMETERS[5], + DEFAULT_PARAMETERS[4] + DEFAULT_PARAMETERS[5] + ]) ) } @@ -331,13 +345,18 @@ mod tests { next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.7254, 5.8627, 5.0, 4.1373]) + Data::from([ + 5.0 + 2.0 * DEFAULT_PARAMETERS[6], + 5.0 + DEFAULT_PARAMETERS[6], + 5.0, + 5.0 - DEFAULT_PARAMETERS[6] + ]) ); let next_difficulty = model.mean_reversion(next_difficulty); next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.6681643, 5.836694, 5.0052238, 4.1737533]) + Data::from([6.744371, 5.8746934, 5.005016, 4.1353383]) ) } @@ -358,19 +377,19 @@ mod tests { s_recall.clone().backward(); assert_eq!( s_recall.to_data(), - Data::from([26.980938, 14.128489, 63.600677, 208.72739]) + Data::from([27.980768, 14.916422, 66.45966, 222.94603]) ); let s_forget = model.stability_after_failure(stability, difficulty, retention); s_forget.clone().backward(); assert_eq!( s_forget.to_data(), - Data::from([1.9016013, 2.0777824, 2.3257504, 2.6291647]) + Data::from([1.9482934, 2.161251, 2.4528089, 2.8098207]) ); let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget); next_stability.clone().backward(); assert_eq!( next_stability.to_data(), - Data::from([1.9016013, 14.128489, 63.600677, 208.72739]) + Data::from([1.9482934, 14.916422, 66.45966, 222.94603]) ) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 1c971bf..70d67bf 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -607,7 +607,7 @@ mod tests { ); assert_eq!( memorized_cnt_per_day[memorized_cnt_per_day.len() - 1], - 3130.8465582271774 + 3199.9526251977177 ) } @@ -663,8 +663,8 @@ mod tests { assert_eq!( results.1.to_vec(), vec![ - 0, 16, 27, 34, 84, 80, 91, 92, 103, 107, 111, 113, 138, 132, 133, 116, 134, 148, - 152, 162, 172, 177, 188, 189, 200, 185, 185, 200, 198, 200 + 0, 16, 27, 34, 84, 80, 91, 92, 104, 106, 109, 112, 133, 123, 139, 121, 136, 149, + 136, 159, 173, 178, 175, 180, 189, 181, 196, 200, 193, 196 ] ); assert_eq!( @@ -687,7 +687,7 @@ mod tests { ..Default::default() }; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8263932); + assert_eq!(optimal_retention, 0.8419900928572013); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) } diff --git a/src/pre_training.rs b/src/pre_training.rs index 872c1dc..175b1ed 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -100,8 +100,8 @@ fn loss( let y_pred = power_forgetting_curve(delta_t, init_s0); let logloss = (-(recall * y_pred.clone().mapv_into(|v| v.ln()) + (1.0 - recall) * (1.0 - &y_pred).mapv_into(|v| v.ln())) - * count.mapv(|v| v.sqrt())) - .sum(); + * count) + .sum(); let l1 = (init_s0 - default_s0).abs() / 16.0; logloss + l1 } @@ -293,11 +293,9 @@ mod tests { let count = Array1::from(vec![435.0, 97.0, 63.0, 38.0, 28.0]); let default_s0 = DEFAULT_PARAMETERS[0] as f64; let actual = loss(&delta_t, &recall, &count, 1.017056, default_s0); - dbg!(actual); - assert_eq!(actual, 22.922578338789826); + assert_eq!(actual, 280.7447802452844); let actual = loss(&delta_t, &recall, &count, 1.017011, default_s0); - dbg!(actual); - assert_eq!(actual, 22.922578344493953); + assert_eq!(actual, 280.7444462249327); } #[test] @@ -335,7 +333,7 @@ mod tests { )]); let actual = search_parameters(pretrainset, 0.9430285915990116); Data::from([*actual.get(&first_rating).unwrap()]) - .assert_approx_eq(&Data::from([1.017_056]), 6); + .assert_approx_eq(&Data::from([0.908_688]), 6); } #[test] @@ -347,10 +345,8 @@ mod tests { (pretrainset, trainset) = filter_outlier(pretrainset, trainset); let items = [pretrainset.clone(), trainset].concat(); let average_recall = calculate_average_recall(&items); - Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq( - &Data::from([1.017_056, 1.829_625, 4.414_563, 10.935_500]), - 6, - ) + Data::from(pretrain(pretrainset, average_recall).unwrap()) + .assert_approx_eq(&Data::from([0.908_688, 1.678_973, 4.216_837, 9.615_904]), 6) } #[test] @@ -363,6 +359,6 @@ mod tests { let mut rating_stability = HashMap::from([(2, 0.35)]); let rating_count = HashMap::from([(2, 1)]); let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap(); - assert_eq!(actual, [0.13822041, 0.35, 1.0034012, 2.6513057,]); + assert_eq!(actual, [0.1217739, 0.35, 0.928426, 3.4544096]); } }