From 8ed9f49bdc11492a6d499f3d64b9dc125102e11d Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 22 Aug 2024 16:34:19 +0800 Subject: [PATCH] Feat/FSRS Simulator (#3257) * test using existed cards * plot new and review * convert learning cards & use line chart * allow draw multiple simulations in the same chart * support hide simulation * convert x axis to Date * convert y from second to minute * support clear last simulation * remove unused import * rename * add hover/tooltip * fallback to default parameters * update default value and maximum of deckSize * add "processing..." * fix mistake --- qt/aqt/mediasrv.py | 1 + rslib/src/scheduler/fsrs/simulator.rs | 80 +++++--- ts/routes/deck-options/FsrsOptions.svelte | 178 ++++++++++++++++- ts/routes/graphs/simulator.ts | 220 ++++++++++++++++++++++ 4 files changed, 457 insertions(+), 22 deletions(-) create mode 100644 ts/routes/graphs/simulator.ts diff --git a/qt/aqt/mediasrv.py b/qt/aqt/mediasrv.py index a613a4795d0..c866ae423f2 100644 --- a/qt/aqt/mediasrv.py +++ b/qt/aqt/mediasrv.py @@ -625,6 +625,7 @@ def handle_on_main() -> None: "set_wants_abort", "evaluate_weights", "get_optimal_retention_parameters", + "simulate_fsrs_review", ] diff --git a/rslib/src/scheduler/fsrs/simulator.rs b/rslib/src/scheduler/fsrs/simulator.rs index 2396f14acee..b41d98d76eb 100644 --- a/rslib/src/scheduler/fsrs/simulator.rs +++ b/rslib/src/scheduler/fsrs/simulator.rs @@ -5,8 +5,10 @@ use anki_proto::scheduler::SimulateFsrsReviewRequest; use anki_proto::scheduler::SimulateFsrsReviewResponse; use fsrs::simulate; use fsrs::SimulatorConfig; +use fsrs::DEFAULT_PARAMETERS; use itertools::Itertools; +use crate::card::CardQueue; use crate::prelude::*; use crate::search::SortMode; @@ -22,9 +24,15 @@ impl Collection { .get_revlog_entries_for_searched_cards_in_card_order()?; let cards = guard.col.storage.all_searched_cards()?; drop(guard); + let days_elapsed = self.timing_today().unwrap().days_elapsed as i32; + let converted_cards = cards + .into_iter() + .filter(|c| c.queue != CardQueue::Suspended && c.queue != CardQueue::PreviewRepeat) + .filter_map(|c| Card::convert(c, days_elapsed, req.days_to_simulate)) + .collect_vec(); let p = self.get_optimal_retention_parameters(revlogs)?; let config = SimulatorConfig { - deck_size: req.deck_size as usize, + deck_size: req.deck_size as usize + converted_cards.len(), learn_span: req.days_to_simulate as usize, max_cost_perday: f32::MAX, max_ivl: req.max_interval as f32, @@ -40,7 +48,19 @@ impl Collection { learn_limit: req.new_limit as usize, review_limit: req.review_limit as usize, }; - let days_elapsed = self.timing_today().unwrap().days_elapsed as i32; + let parameters = if req.weights.is_empty() { + DEFAULT_PARAMETERS.to_vec() + } else if req.weights.len() != 19 { + if req.weights.len() == 17 { + let mut parameters = req.weights.to_vec(); + parameters.extend_from_slice(&[0.0, 0.0]); + parameters + } else { + return Err(AnkiError::FsrsWeightsInvalid); + } + } else { + req.weights.to_vec() + }; let ( accumulated_knowledge_acquisition, daily_review_count, @@ -48,15 +68,10 @@ impl Collection { daily_time_cost, ) = simulate( &config, - &req.weights, + ¶meters, req.desired_retention, None, - Some( - cards - .into_iter() - .filter_map(|c| Card::convert(c, days_elapsed)) - .collect_vec(), - ), + Some(converted_cards), ); Ok(SimulateFsrsReviewResponse { accumulated_knowledge_acquisition: accumulated_knowledge_acquisition.to_vec(), @@ -68,19 +83,42 @@ impl Collection { } impl Card { - fn convert(card: Card, days_elapsed: i32) -> Option { + fn convert(card: Card, days_elapsed: i32, day_to_simulate: u32) -> Option { match card.memory_state { - Some(state) => { - let due = card.original_or_current_due(); - let relative_due = due - days_elapsed; - Some(fsrs::Card { - difficulty: state.difficulty, - stability: state.stability, - last_date: (relative_due - card.interval as i32) as f32, - due: relative_due as f32, - }) - } - None => None, + Some(state) => match card.queue { + CardQueue::DayLearn | CardQueue::Review => { + let due = card.original_or_current_due(); + let relative_due = due - days_elapsed; + Some(fsrs::Card { + difficulty: state.difficulty, + stability: state.stability, + last_date: (relative_due - card.interval as i32) as f32, + due: relative_due as f32, + }) + } + CardQueue::New => Some(fsrs::Card { + difficulty: 1e-10, + stability: 1e-10, + last_date: 0.0, + due: day_to_simulate as f32, + }), + CardQueue::Learn | CardQueue::SchedBuried | CardQueue::UserBuried => { + Some(fsrs::Card { + difficulty: state.difficulty, + stability: state.stability, + last_date: 0.0, + due: 0.0, + }) + } + CardQueue::PreviewRepeat => None, + CardQueue::Suspended => None, + }, + None => Some(fsrs::Card { + difficulty: 1e-10, + stability: 1e-10, + last_date: 0.0, + due: day_to_simulate as f32, + }), } } } diff --git a/ts/routes/deck-options/FsrsOptions.svelte b/ts/routes/deck-options/FsrsOptions.svelte index e836e1dc8b4..569dee9ddc9 100644 --- a/ts/routes/deck-options/FsrsOptions.svelte +++ b/ts/routes/deck-options/FsrsOptions.svelte @@ -7,10 +7,15 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html ComputeRetentionProgress, type ComputeWeightsProgress, } from "@generated/anki/collection_pb"; - import { ComputeOptimalRetentionRequest } from "@generated/anki/scheduler_pb"; + import { + ComputeOptimalRetentionRequest, + SimulateFsrsReviewRequest, + type SimulateFsrsReviewResponse, + } from "@generated/anki/scheduler_pb"; import { computeFsrsWeights, computeOptimalRetention, + simulateFsrsReview, evaluateWeights, setWantsAbort, } from "@generated/backend"; @@ -28,6 +33,14 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html import Warning from "./Warning.svelte"; import WeightsInputRow from "./WeightsInputRow.svelte"; import WeightsSearchRow from "./WeightsSearchRow.svelte"; + import { renderSimulationChart, type Point } from "../graphs/simulator"; + import Graph from "../graphs/Graph.svelte"; + import HoverColumns from "../graphs/HoverColumns.svelte"; + import CumulativeOverlay from "../graphs/CumulativeOverlay.svelte"; + import AxisTicks from "../graphs/AxisTicks.svelte"; + import NoDataOverlay from "../graphs/NoDataOverlay.svelte"; + import TableData from "../graphs/TableData.svelte"; + import { defaultGraphBounds, type TableDatum } from "../graphs/graph-helpers"; export let state: DeckOptionsState; export let openHelpModal: (String) => void; @@ -68,6 +81,17 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html optimalRetentionRequest.daysToSimulate = 3650; } + const simulateFsrsRequest = new SimulateFsrsReviewRequest({ + weights: $config.fsrsWeights, + desiredRetention: $config.desiredRetention, + deckSize: 0, + daysToSimulate: 365, + newLimit: $config.newPerDay, + reviewLimit: $config.reviewsPerDay, + maxInterval: $config.maximumReviewInterval, + search: `preset:"${state.getCurrentName()}" -is:suspended`, + }); + function getRetentionWarning(retention: number): string { const decay = -0.5; const factor = 0.9 ** (1 / decay) - 1; @@ -256,6 +280,69 @@ License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html } return tr.deckConfigPredictedOptimalRetention({ num: retention.toFixed(2) }); } + + let tableData: TableDatum[] = [] as any; + const bounds = defaultGraphBounds(); + let svg = null as HTMLElement | SVGElement | null; + const title = tr.statisticsReviewsTitle(); + let simulationNumber = 0; + + let points: Point[] = []; + + function movingAverage(y: number[], windowSize: number): number[] { + const result: number[] = []; + for (let i = 0; i < y.length; i++) { + let sum = 0; + let count = 0; + for (let j = Math.max(0, i - windowSize + 1); j <= i; j++) { + sum += y[j]; + count++; + } + result.push(sum / count); + } + return result; + } + + $: simulateProgressString = ""; + + async function simulateFsrs(): Promise { + let resp: SimulateFsrsReviewResponse | undefined; + simulationNumber += 1; + try { + await runWithBackendProgress( + async () => { + simulateFsrsRequest.weights = $config.fsrsWeights; + simulateFsrsRequest.desiredRetention = $config.desiredRetention; + simulateFsrsRequest.search = `preset:"${state.getCurrentName()}" -is:suspended`; + simulateProgressString = "processing..."; + resp = await simulateFsrsReview(simulateFsrsRequest); + }, + () => {}, + ); + } finally { + if (resp) { + simulateProgressString = ""; + const dailyTimeCost = movingAverage( + resp.dailyTimeCost, + Math.round(simulateFsrsRequest.daysToSimulate / 50), + ); + points = points.concat( + dailyTimeCost.map((v, i) => ({ + x: i, + y: v, + label: simulationNumber, + })), + ); + tableData = renderSimulationChart(svg as SVGElement, bounds, points); + } + } + } + + function clearSimulation(): void { + points = points.filter((p) => p.label !== simulationNumber); + simulationNumber = Math.max(0, simulationNumber - 1); + tableData = renderSimulationChart(svg as SVGElement, bounds, points); + } +
+
+ FSRS simulator (experimental) + + + openHelpModal("simulateFsrsReview")}> + Days to simulate + + + + + openHelpModal("simulateFsrsReview")}> + Additional new cards to simulate + + + + + openHelpModal("simulateFsrsReview")}> + New cards/day + + + + + openHelpModal("simulateFsrsReview")}> + Maximum reviews/day + + + + + openHelpModal("simulateFsrsReview")}> + Maximum interval + + + + + + +
{simulateProgressString}
+ + + + + + + + + + + +
+
+ diff --git a/ts/routes/graphs/simulator.ts b/ts/routes/graphs/simulator.ts new file mode 100644 index 00000000000..52c7917c633 --- /dev/null +++ b/ts/routes/graphs/simulator.ts @@ -0,0 +1,220 @@ +// Copyright: Ankitects Pty Ltd and contributors +// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html +import { localizedNumber } from "@tslib/i18n"; +import { + axisBottom, + axisLeft, + bisector, + line, + max, + pointer, + rollup, + scaleLinear, + scaleTime, + schemeCategory10, + select, + timeFormat, +} from "d3"; + +import type { GraphBounds, TableDatum } from "./graph-helpers"; +import { setDataAvailable } from "./graph-helpers"; +import { hideTooltip, showTooltip } from "./tooltip"; + +export interface Point { + x: number; + y: number; + label: number; +} + +export function renderSimulationChart( + svgElem: SVGElement, + bounds: GraphBounds, + data: Point[], +): TableDatum[] { + const svg = select(svgElem); + svg.selectAll(".lines").remove(); + svg.selectAll(".hover-columns").remove(); + svg.selectAll(".focus-line").remove(); + svg.selectAll(".legend").remove(); + if (data.length == 0) { + setDataAvailable(svg, false); + return []; + } + const trans = svg.transition().duration(600) as any; + + // Prepare data + const today = new Date(); + const convertedData = data.map(d => ({ + ...d, + date: new Date(today.getTime() + d.x * 24 * 60 * 60 * 1000), + yMinutes: d.y / 60, + })); + const xMin = today; + const xMax = max(convertedData, d => d.date); + + const x = scaleTime() + .domain([xMin, xMax!]) + .range([bounds.marginLeft, bounds.width - bounds.marginRight]); + const formatDate = timeFormat("%Y-%m-%d"); + + svg.select(".x-ticks") + .call((selection) => + selection.transition(trans).call( + axisBottom(x) + .ticks(7) + .tickFormat((d: any) => formatDate(d)) + .tickSizeOuter(0), + ) + ) + .attr("direction", "ltr"); + // y scale + + const yTickFormat = (n: number): string => { + if (Math.round(n) != n) { + return ""; + } else { + return localizedNumber(n); + } + }; + + const yMax = max(convertedData, d => d.yMinutes)!; + const y = scaleLinear() + .range([bounds.height - bounds.marginBottom, bounds.marginTop]) + .domain([0, yMax]) + .nice(); + svg.select(".y-ticks") + .call((selection) => + selection.transition(trans).call( + axisLeft(y) + .ticks(bounds.height / 50) + .tickSizeOuter(0) + .tickFormat(yTickFormat as any), + ) + ) + .attr("direction", "ltr"); + + svg.select(".y-ticks") + .append("text") + .attr("class", "y-axis-title") + .attr("transform", "rotate(-90)") + .attr("y", 0 - bounds.marginLeft) + .attr("x", 0 - (bounds.height / 2)) + .attr("dy", "1em") + .attr("fill", "currentColor") + .style("text-anchor", "middle") + .text("Review Time per day (minutes)"); + + // x lines + const points = convertedData.map((d) => [x(d.date), y(d.yMinutes), d.label]); + const groups = rollup(points, v => Object.assign(v, { z: v[0][2] }), d => d[2]); + + const color = schemeCategory10; + + svg.append("g") + .attr("class", "lines") + .attr("fill", "none") + .attr("stroke-width", 1.5) + .attr("stroke-linejoin", "round") + .attr("stroke-linecap", "round") + .selectAll("path") + .data(Array.from(groups.entries())) + .join("path") + .style("mix-blend-mode", "multiply") + .attr("stroke", (d, i) => color[i % color.length]) + .attr("d", d => line()(d[1].map(p => [p[0], p[1]]))) + .attr("data-group", d => d[0]); + + const focusLine = svg.append("line") + .attr("class", "focus-line") + .attr("y1", bounds.marginTop) + .attr("y2", bounds.height - bounds.marginBottom) + .attr("stroke", "black") + .attr("stroke-width", 1) + .style("opacity", 0); + + const LongestGroupData = Array.from(groups.values()).reduce((a, b) => a.length > b.length ? a : b); + const barWidth = bounds.width / LongestGroupData.length; + + // hover/tooltip + svg.append("g") + .attr("class", "hover-columns") + .selectAll("rect") + .data(LongestGroupData) + .join("rect") + .attr("x", d => d[0] - barWidth / 2) + .attr("y", bounds.marginTop) + .attr("width", barWidth) + .attr("height", bounds.height - bounds.marginTop - bounds.marginBottom) + .attr("fill", "transparent") + .on("mousemove", mousemove) + .on("mouseout", hideTooltip); + + function mousemove(event: MouseEvent, d: any): void { + pointer(event, document.body); + const date = x.invert(d[0]); + + const groupData: { [key: string]: number } = {}; + + groups.forEach((groupPoints, key) => { + const bisect = bisector((d: number[]) => x.invert(d[0])).left; + const index = bisect(groupPoints, date); + const dataPoint = groupPoints[index - 1] || groupPoints[index]; + + if (dataPoint) { + groupData[key] = y.invert(dataPoint[1]); + } + }); + + focusLine.attr("x1", d[0]).attr("x2", d[0]).style("opacity", 1); + + let tooltipContent = `Date: ${timeFormat("%Y-%m-%d")(date)}
`; + for (const [key, value] of Object.entries(groupData)) { + tooltipContent += `Simulation ${key}: ${value.toFixed(2)} minutes
`; + } + + showTooltip(tooltipContent, event.pageX, event.pageY); + } + + const legend = svg.append("g") + .attr("class", "legend") + .attr("font-family", "sans-serif") + .attr("font-size", 10) + .attr("text-anchor", "start") + .selectAll("g") + .data(Array.from(groups.keys())) + .join("g") + .attr("transform", (d, i) => `translate(0,${i * 20})`) + .attr("cursor", "pointer") + .on("click", (event, d) => toggleGroup(event, d)); + + legend.append("rect") + .attr("x", bounds.width - bounds.marginRight + 10) + .attr("width", 19) + .attr("height", 19) + .attr("fill", (d, i) => color[i % color.length]); + + legend.append("text") + .attr("x", bounds.width - bounds.marginRight + 34) + .attr("y", 9.5) + .attr("dy", "0.32em") + .text(d => `Simulation ${d}`); + + const toggleGroup = (event: MouseEvent, d: number) => { + const group = d; + const path = svg.select(`path[data-group="${group}"]`); + const hidden = path.classed("hidden"); + const target = event.currentTarget as HTMLElement; + + path.classed("hidden", !hidden); + path.style("display", () => hidden ? null : "none"); + + select(target).select("rect") + .style("opacity", hidden ? 1 : 0.5); + }; + + setDataAvailable(svg, true); + + const tableData: TableDatum[] = []; + + return tableData; +}