diff --git a/optuna_dashboard/ts/components/GraphEdf.tsx b/optuna_dashboard/ts/components/GraphEdf.tsx index e5c3a083..234b2e6f 100644 --- a/optuna_dashboard/ts/components/GraphEdf.tsx +++ b/optuna_dashboard/ts/components/GraphEdf.tsx @@ -1,9 +1,4 @@ -import { - GraphContainer, - PlotEdf, - getPlotDomId, - useGraphComponentState, -} from "@optuna/react" +import { GraphContainer, PlotEdf, useGraphComponentState } from "@optuna/react" import * as plotly from "plotly.js-dist-min" import React, { FC, useEffect } from "react" import { StudyDetail } from "ts/types/optuna" @@ -22,6 +17,8 @@ export const GraphEdf: FC<{ } } +const domId = "graph-edf" + const GraphEdfBackend: FC<{ studies: StudyDetail[] }> = ({ studies }) => { @@ -29,7 +26,6 @@ const GraphEdfBackend: FC<{ const { graphComponentState, notifyGraphDidRender } = useGraphComponentState() const studyIds = studies.map((s) => s.id) - const domId = getPlotDomId(-1) const numCompletedTrials = studies.reduce( (acc, study) => acc + study?.trials.filter((t) => t.state === "Complete").length, diff --git a/optuna_dashboard/ts/components/GraphSlice.tsx b/optuna_dashboard/ts/components/GraphSlice.tsx index d404c2cf..d758c661 100644 --- a/optuna_dashboard/ts/components/GraphSlice.tsx +++ b/optuna_dashboard/ts/components/GraphSlice.tsx @@ -1,41 +1,14 @@ -import { - FormControl, - FormLabel, - Grid, - MenuItem, - Select, - SelectChangeEvent, - Switch, - Typography, - useTheme, -} from "@mui/material" import { GraphContainer, + PlotSlice, useGraphComponentState, - useMergedUnionSearchSpace, -} from "@optuna/react" -import { - Target, - useFilteredTrials, - useObjectiveAndUserAttrTargets, - useParamTargets, } from "@optuna/react" -import * as Optuna from "@optuna/types" import * as plotly from "plotly.js-dist-min" -import React, { FC, useEffect, useState } from "react" -import { SearchSpaceItem, StudyDetail } from "ts/types/optuna" +import React, { FC, useEffect } from "react" +import { StudyDetail } from "ts/types/optuna" import { PlotType } from "../apiClient" import { usePlot } from "../hooks/usePlot" -import { useBackendRender, usePlotlyColorTheme } from "../state" - -const plotDomId = "graph-slice" - -const isLogScale = (s: SearchSpaceItem): boolean => { - if (s.distribution.type === "CategoricalDistribution") { - return false - } - return s.distribution.log -} +import { useBackendRender } from "../state" export const GraphSlice: FC<{ study: StudyDetail | null @@ -43,10 +16,12 @@ export const GraphSlice: FC<{ if (useBackendRender()) { return } else { - return + return } } +const domId = "graph-slice" + const GraphSliceBackend: FC<{ study: StudyDetail | null }> = ({ study = null }) => { @@ -64,7 +39,7 @@ const GraphSliceBackend: FC<{ useEffect(() => { if (data && layout && graphComponentState !== "componentWillMount") { - plotly.react(plotDomId, data, layout).then(notifyGraphDidRender) + plotly.react(domId, data, layout).then(notifyGraphDidRender) } }, [data, layout, graphComponentState]) useEffect(() => { @@ -75,253 +50,8 @@ const GraphSliceBackend: FC<{ return ( ) } - -const GraphSliceFrontend: FC<{ - study: StudyDetail | null -}> = ({ study = null }) => { - const { graphComponentState, notifyGraphDidRender } = useGraphComponentState() - - const theme = useTheme() - const colorTheme = usePlotlyColorTheme(theme.palette.mode) - - const [objectiveTargets, selectedObjective, setObjectiveTarget] = - useObjectiveAndUserAttrTargets(study) - const searchSpace = useMergedUnionSearchSpace(study?.union_search_space) - const [paramTargets, selectedParamTarget, setParamTarget] = - useParamTargets(searchSpace) - const [logYScale, setLogYScale] = useState(false) - - const trials = useFilteredTrials( - study, - selectedParamTarget !== null - ? [selectedObjective, selectedParamTarget] - : [selectedObjective], - false - ) - - useEffect(() => { - if (graphComponentState !== "componentWillMount") { - plotSlice( - trials, - selectedObjective, - selectedParamTarget, - searchSpace.find((s) => s.name === selectedParamTarget?.key) || null, - logYScale, - colorTheme - )?.then(notifyGraphDidRender) - } - }, [ - trials, - selectedObjective, - searchSpace, - selectedParamTarget, - logYScale, - colorTheme, - graphComponentState, - ]) - - const handleObjectiveChange = (event: SelectChangeEvent) => { - setObjectiveTarget(event.target.value) - } - - const handleSelectedParam = (e: SelectChangeEvent) => { - setParamTarget(e.target.value) - } - - const handleLogYScaleChange = () => { - setLogYScale(!logYScale) - } - - return ( - - - - Slice - - {objectiveTargets.length !== 1 && ( - - Objective: - - - )} - {paramTargets.length !== 0 && selectedParamTarget !== null && ( - - Parameter: - - - )} - - Log y scale: - - - - - - - - ) -} - -const plotSlice = ( - trials: Optuna.Trial[], - objectiveTarget: Target, - selectedParamTarget: Target | null, - selectedParamSpace: SearchSpaceItem | null, - logYScale: boolean, - colorTheme: Partial -) => { - if (document.getElementById(plotDomId) === null) { - return - } - - const layout: Partial = { - margin: { - l: 50, - t: 0, - r: 50, - b: 0, - }, - xaxis: { - title: selectedParamTarget?.toLabel() || "", - type: - selectedParamSpace !== null && isLogScale(selectedParamSpace) - ? "log" - : "linear", - gridwidth: 1, - automargin: true, - }, - yaxis: { - title: "Objective Value", - type: logYScale ? "log" : "linear", - gridwidth: 1, - automargin: true, - }, - showlegend: false, - uirevision: "true", - template: colorTheme, - } - if ( - selectedParamSpace === null || - selectedParamTarget === null || - trials.length === 0 - ) { - return plotly.react(plotDomId, [], layout) - } - - const feasibleTrials: Optuna.Trial[] = [] - const infeasibleTrials: Optuna.Trial[] = [] - trials.forEach((t) => { - if (t.constraints.every((c) => c <= 0)) { - feasibleTrials.push(t) - } else { - infeasibleTrials.push(t) - } - }) - - const feasibleObjectiveValues: number[] = feasibleTrials.map( - (t) => objectiveTarget.getTargetValue(t) as number - ) - const infeasibleObjectiveValues: number[] = infeasibleTrials.map( - (t) => objectiveTarget.getTargetValue(t) as number - ) - - const feasibleValues = feasibleTrials.map( - (t) => selectedParamTarget.getTargetValue(t) as number - ) - const infeasibleValues = infeasibleTrials.map( - (t) => selectedParamTarget.getTargetValue(t) as number - ) - const trace: plotly.Data[] = [ - { - type: "scatter", - x: feasibleValues, - y: feasibleObjectiveValues, - mode: "markers", - name: "Feasible Trial", - marker: { - color: feasibleTrials.map((t) => t.number), - colorscale: "Blues", - reversescale: true, - colorbar: { - title: "Trial", - }, - line: { - color: "Grey", - width: 0.5, - }, - }, - }, - { - type: "scatter", - x: infeasibleValues, - y: infeasibleObjectiveValues, - mode: "markers", - name: "Infeasible Trial", - marker: { - color: "#cccccc", - reversescale: true, - }, - }, - ] - if (selectedParamSpace.distribution.type !== "CategoricalDistribution") { - layout["xaxis"] = { - title: selectedParamTarget.toLabel(), - type: isLogScale(selectedParamSpace) ? "log" : "linear", - gridwidth: 1, - automargin: true, // Otherwise the label is outside of the plot - } - } else { - const vocabArr = selectedParamSpace.distribution.choices.map( - (c) => c?.toString() ?? "null" - ) - const tickvals: number[] = vocabArr.map((v, i) => i) - layout["xaxis"] = { - title: selectedParamTarget.toLabel(), - type: "linear", - gridwidth: 1, - tickvals: tickvals, - ticktext: vocabArr, - automargin: true, // Otherwise the label is outside of the plot - } - } - return plotly.react(plotDomId, trace, layout) -} diff --git a/tslib/react/src/components/PlotEdf.tsx b/tslib/react/src/components/PlotEdf.tsx index 17a0872e..56dce17e 100644 --- a/tslib/react/src/components/PlotEdf.tsx +++ b/tslib/react/src/components/PlotEdf.tsx @@ -12,7 +12,7 @@ export type EdfPlotInfo = { trials: Optuna.Trial[] } -export const getPlotDomId = (objectiveId: number) => `graph-edf-${objectiveId}` +const getPlotDomId = (objectiveId: number) => `plot-edf-${objectiveId}` export const PlotEdf: FC<{ studies: Optuna.Study[] diff --git a/tslib/react/src/components/PlotSlice.stories.tsx b/tslib/react/src/components/PlotSlice.stories.tsx new file mode 100644 index 00000000..bad059d1 --- /dev/null +++ b/tslib/react/src/components/PlotSlice.stories.tsx @@ -0,0 +1,37 @@ +import { CssBaseline, ThemeProvider } from "@mui/material" +import { Meta, StoryObj } from "@storybook/react" +import React from "react" +import { useMockStudy } from "../MockStudies" +import { lightTheme } from "../styles/lightTheme" +import { PlotSlice } from "./PlotSlice" + +const meta: Meta = { + component: PlotSlice, + title: "PlotSlice", + tags: ["autodocs"], + decorators: [ + (Story, storyContext) => { + const { study } = useMockStudy(storyContext.parameters?.studyId) + if (!study) return

loading...

+ return ( + + + + + ) + }, + ], +} + +export default meta +type Story = StoryObj + +export const MockStudyExample1: Story = { + parameters: { + studyId: 1, + }, +} diff --git a/tslib/react/src/components/PlotSlice.tsx b/tslib/react/src/components/PlotSlice.tsx new file mode 100644 index 00000000..3a70c5e8 --- /dev/null +++ b/tslib/react/src/components/PlotSlice.tsx @@ -0,0 +1,281 @@ +import { + FormControl, + FormLabel, + Grid, + MenuItem, + Select, + SelectChangeEvent, + Switch, + Typography, + useTheme, +} from "@mui/material" +import * as Optuna from "@optuna/types" +import * as plotly from "plotly.js-dist-min" +import { FC, useEffect, useState } from "react" +import { useGraphComponentState } from "../hooks/useGraphComponentState" +import { useMergedUnionSearchSpace } from "../utils/searchSpace" +import { + Target, + useFilteredTrials, + useObjectiveAndUserAttrTargets, + useParamTargets, +} from "../utils/trialFilter" +import { GraphContainer } from "./GraphContainer" +import { plotlyDarkTemplate } from "./PlotlyDarkMode" + +const isLogScale = (s: Optuna.SearchSpaceItem): boolean => { + if (s.distribution.type === "CategoricalDistribution") { + return false + } + return s.distribution.log +} + +const domId = "plot-slice" + +export const PlotSlice: FC<{ + study: Optuna.Study | null +}> = ({ study = null }) => { + const { graphComponentState, notifyGraphDidRender } = useGraphComponentState() + + const theme = useTheme() + + const [objectiveTargets, selectedObjective, setObjectiveTarget] = + useObjectiveAndUserAttrTargets(study) + const searchSpace = useMergedUnionSearchSpace(study?.union_search_space) + const [paramTargets, selectedParamTarget, setParamTarget] = + useParamTargets(searchSpace) + const [logYScale, setLogYScale] = useState(false) + + const trials = useFilteredTrials( + study, + selectedParamTarget !== null + ? [selectedObjective, selectedParamTarget] + : [selectedObjective], + false + ) + + // biome-ignore lint/correctness/useExhaustiveDependencies: + useEffect(() => { + if (graphComponentState !== "componentWillMount") { + plotSlice( + trials, + selectedObjective, + selectedParamTarget, + searchSpace.find((s) => s.name === selectedParamTarget?.key) || null, + logYScale, + theme.palette.mode + )?.then(notifyGraphDidRender) + } + }, [ + trials, + selectedObjective, + searchSpace, + selectedParamTarget, + logYScale, + theme.palette.mode, + graphComponentState, + ]) + + const handleObjectiveChange = (event: SelectChangeEvent) => { + setObjectiveTarget(event.target.value) + } + + const handleSelectedParam = (e: SelectChangeEvent) => { + setParamTarget(e.target.value) + } + + const handleLogYScaleChange = () => { + setLogYScale(!logYScale) + } + + return ( + + + + Slice + + {objectiveTargets.length !== 1 && ( + + Objective: + + + )} + {paramTargets.length !== 0 && selectedParamTarget !== null && ( + + Parameter: + + + )} + + Log y scale: + + + + + + + + ) +} + +const plotSlice = ( + trials: Optuna.Trial[], + objectiveTarget: Target, + selectedParamTarget: Target | null, + selectedParamSpace: Optuna.SearchSpaceItem | null, + logYScale: boolean, + mode: string +) => { + if (document.getElementById(domId) === null) { + return + } + + const layout: Partial = { + margin: { + l: 50, + t: 0, + r: 50, + b: 0, + }, + xaxis: { + title: selectedParamTarget?.toLabel() || "", + type: + selectedParamSpace !== null && isLogScale(selectedParamSpace) + ? "log" + : "linear", + gridwidth: 1, + automargin: true, + }, + yaxis: { + title: "Objective Value", + type: logYScale ? "log" : "linear", + gridwidth: 1, + automargin: true, + }, + showlegend: false, + uirevision: "true", + template: mode === "dark" ? plotlyDarkTemplate : {}, + } + if ( + selectedParamSpace === null || + selectedParamTarget === null || + trials.length === 0 + ) { + return plotly.react(domId, [], layout) + } + + const feasibleTrials: Optuna.Trial[] = [] + const infeasibleTrials: Optuna.Trial[] = [] + // biome-ignore lint/complexity/noForEach: + trials.forEach((t) => { + if (t.constraints.every((c) => c <= 0)) { + feasibleTrials.push(t) + } else { + infeasibleTrials.push(t) + } + }) + + const feasibleObjectiveValues: number[] = feasibleTrials.map( + (t) => objectiveTarget.getTargetValue(t) as number + ) + const infeasibleObjectiveValues: number[] = infeasibleTrials.map( + (t) => objectiveTarget.getTargetValue(t) as number + ) + + const feasibleValues = feasibleTrials.map( + (t) => selectedParamTarget.getTargetValue(t) as number + ) + const infeasibleValues = infeasibleTrials.map( + (t) => selectedParamTarget.getTargetValue(t) as number + ) + const trace: plotly.Data[] = [ + { + type: "scatter", + x: feasibleValues, + y: feasibleObjectiveValues, + mode: "markers", + name: "Feasible Trial", + marker: { + color: feasibleTrials.map((t) => t.number), + colorscale: "Blues", + reversescale: true, + colorbar: { + title: "Trial", + }, + line: { + color: "Grey", + width: 0.5, + }, + }, + }, + { + type: "scatter", + x: infeasibleValues, + y: infeasibleObjectiveValues, + mode: "markers", + name: "Infeasible Trial", + marker: { + color: "#cccccc", + reversescale: true, + }, + }, + ] + if (selectedParamSpace.distribution.type !== "CategoricalDistribution") { + layout.xaxis = { + title: selectedParamTarget.toLabel(), + type: isLogScale(selectedParamSpace) ? "log" : "linear", + gridwidth: 1, + automargin: true, // Otherwise the label is outside of the plot + } + } else { + const vocabArr = selectedParamSpace.distribution.choices.map( + (c) => c?.toString() ?? "null" + ) + const tickvals: number[] = vocabArr.map((_v, i) => i) + layout.xaxis = { + title: selectedParamTarget.toLabel(), + type: "linear", + gridwidth: 1, + tickvals: tickvals, + ticktext: vocabArr, + automargin: true, // Otherwise the label is outside of the plot + } + } + return plotly.react(domId, trace, layout) +} diff --git a/tslib/react/src/components/PlotSliceDark.stories.tsx b/tslib/react/src/components/PlotSliceDark.stories.tsx new file mode 100644 index 00000000..b9b67eb7 --- /dev/null +++ b/tslib/react/src/components/PlotSliceDark.stories.tsx @@ -0,0 +1,40 @@ +import { CssBaseline, ThemeProvider } from "@mui/material" +import { Meta, StoryObj } from "@storybook/react" +import React from "react" +import { useMockStudy } from "../MockStudies" +import { darkTheme } from "../styles/darkTheme" +import { PlotSlice } from "./PlotSlice" + +const meta: Meta = { + component: PlotSlice, + title: "PlotSliceDark", + tags: ["autodocs"], + decorators: [ + (Story, storyContext) => { + const { study } = useMockStudy(storyContext.parameters?.studyId) + if (!study) return

loading...

+ return ( + + + + + ) + }, + ], + parameters: { + backgrounds: { default: "dark" }, + }, +} + +export default meta +type Story = StoryObj + +export const MockStudyExample1: Story = { + parameters: { + studyId: 1, + }, +} diff --git a/tslib/react/src/index.ts b/tslib/react/src/index.ts index e1b4243f..71518938 100644 --- a/tslib/react/src/index.ts +++ b/tslib/react/src/index.ts @@ -1,10 +1,11 @@ export { DataGrid } from "./components/DataGrid" export { plotlyDarkTemplate } from "./components/PlotlyDarkMode" -export { PlotEdf, getPlotDomId } from "./components/PlotEdf" +export { PlotEdf } from "./components/PlotEdf" export type { EdfPlotInfo } from "./components/PlotEdf" export { PlotHistory } from "./components/PlotHistory" export { PlotImportance } from "./components/PlotImportance" export { PlotIntermediateValues } from "./components/PlotIntermediateValues" +export { PlotSlice } from "./components/PlotSlice" export { TrialTable } from "./components/TrialTable" export { GraphContainer } from "./components/GraphContainer" export { useGraphComponentState } from "./hooks/useGraphComponentState" diff --git a/tslib/react/test/PlotSlice.test.tsx b/tslib/react/test/PlotSlice.test.tsx new file mode 100644 index 00000000..d0a7a41e --- /dev/null +++ b/tslib/react/test/PlotSlice.test.tsx @@ -0,0 +1,32 @@ +import * as Optuna from "@optuna/types" +import { render, screen } from "@testing-library/react" +import React from "react" +import { describe, expect, test } from "vitest" +import { PlotSlice } from "../src/components/PlotSlice" + +describe("PlotSlice Tests", async () => { + const setup = ({ + study, + dataTestId, + }: { study: Optuna.Study; dataTestId: string }) => { + const Wrapper = ({ + dataTestId, + children, + }: { + dataTestId: string + children: React.ReactNode + }) =>
{children}
+ return render( + + + + ) + } + + for (const study of window.mockStudies) { + test(`PlotSlice (study name: ${study.name})`, () => { + setup({ study, dataTestId: `plot-slice-${study.id}` }) + expect(screen.getByTestId(`plot-slice-${study.id}`)).toBeInTheDocument() + }) + } +})