From 9e5a01564a5ec98618835ff49c80dc05831a4f76 Mon Sep 17 00:00:00 2001 From: Jonas Kulhanek Date: Mon, 29 Jan 2024 14:14:04 +0100 Subject: [PATCH 01/10] Refac components --- src/viser/_gui_api.py | 157 ++-- src/viser/_gui_handles.py | 7 +- src/viser/_messages.py | 31 +- .../client/src/ControlPanel/Generated.tsx | 765 ++---------------- .../src/ControlPanel/GuiComponentContext.tsx | 18 + .../client/src/ControlPanel/GuiState.tsx | 25 +- src/viser/client/src/WebsocketInterface.tsx | 7 +- src/viser/client/src/WebsocketMessages.tsx | 79 +- src/viser/client/src/components/Button.tsx | 69 ++ .../client/src/components/ButtonGroup.tsx | 40 + src/viser/client/src/components/Checkbox.tsx | 49 ++ src/viser/client/src/components/Dropdown.tsx | 36 + src/viser/client/src/components/Folder.tsx | 77 ++ src/viser/client/src/components/Markdown.tsx | 18 + .../client/src/components/NumberInput.tsx | 36 + src/viser/client/src/components/Rgb.tsx | 28 + src/viser/client/src/components/Rgba.tsx | 27 + src/viser/client/src/components/Slider.tsx | 91 +++ src/viser/client/src/components/TabGroup.tsx | 56 ++ src/viser/client/src/components/TextInput.tsx | 28 + src/viser/client/src/components/Vector2.tsx | 23 + src/viser/client/src/components/Vector3.tsx | 23 + src/viser/client/src/components/common.tsx | 147 ++++ src/viser/client/src/components/utils.tsx | 32 + 24 files changed, 1039 insertions(+), 830 deletions(-) create mode 100644 src/viser/client/src/ControlPanel/GuiComponentContext.tsx create mode 100644 src/viser/client/src/components/Button.tsx create mode 100644 src/viser/client/src/components/ButtonGroup.tsx create mode 100644 src/viser/client/src/components/Checkbox.tsx create mode 100644 src/viser/client/src/components/Dropdown.tsx create mode 100644 src/viser/client/src/components/Folder.tsx create mode 100644 src/viser/client/src/components/Markdown.tsx create mode 100644 src/viser/client/src/components/NumberInput.tsx create mode 100644 src/viser/client/src/components/Rgb.tsx create mode 100644 src/viser/client/src/components/Rgba.tsx create mode 100644 src/viser/client/src/components/Slider.tsx create mode 100644 src/viser/client/src/components/TabGroup.tsx create mode 100644 src/viser/client/src/components/TextInput.tsx create mode 100644 src/viser/client/src/components/Vector2.tsx create mode 100644 src/viser/client/src/components/Vector3.tsx create mode 100644 src/viser/client/src/components/common.tsx create mode 100644 src/viser/client/src/components/utils.tsx diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 53e0f46b5..9a8f7f4fc 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -196,6 +196,7 @@ def add_gui_folder( label: str, order: Optional[float] = None, expand_by_default: bool = True, + visible: bool = True, ) -> GuiFolderHandle: """Add a folder, and return a handle that can be used to populate it. @@ -204,6 +205,7 @@ def add_gui_folder( order: Optional ordering, smallest values will be displayed first. expand_by_default: Open the folder by default. Set to False to collapse it by default. + visible: Whether the component is visible. Returns: A handle that can be used as a context to populate the folder. @@ -217,6 +219,7 @@ def add_gui_folder( label=label, container_id=self._get_container_id(), expand_by_default=expand_by_default, + visible=visible, ) ) return GuiFolderHandle( @@ -258,11 +261,13 @@ def add_gui_modal( def add_gui_tab_group( self, order: Optional[float] = None, + visible: bool = True, ) -> GuiTabGroupHandle: """Add a tab group. Args: order: Optional ordering, smallest values will be displayed first. + visible: Whether the component is visible. Returns: A handle that can be used as a context to populate the tab group. @@ -277,6 +282,7 @@ def add_gui_tab_group( _gui_api=self, _container_id=self._get_container_id(), _order=order, + _visible=visible, ) def add_gui_markdown( @@ -284,6 +290,7 @@ def add_gui_markdown( content: str, image_root: Optional[Path] = None, order: Optional[float] = None, + visible: bool = True, ) -> GuiMarkdownHandle: """Add markdown to the GUI. @@ -291,6 +298,7 @@ def add_gui_markdown( content: Markdown content to display. image_root: Optional root directory to resolve relative image paths. order: Optional ordering, smallest values will be displayed first. + visible: Whether the component is visible. Returns: A handle that can be used to interact with the GUI element. @@ -298,7 +306,7 @@ def add_gui_markdown( handle = GuiMarkdownHandle( _gui_api=self, _id=_make_unique_id(), - _visible=True, + _visible=visible, _container_id=self._get_container_id(), _order=_apply_default_order(order), _image_root=image_root, @@ -357,19 +365,19 @@ def add_gui_button( order = _apply_default_order(order) return GuiButtonHandle( self._create_gui_input( - initial_value=False, + value=False, message=_messages.GuiAddButtonMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=False, + value=False, color=color, icon_base64=None if icon is None else base64_from_icon(icon), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, is_button=True, )._impl ) @@ -425,23 +433,23 @@ def add_gui_button_group( Returns: A handle that can be used to interact with the GUI element. """ - initial_value = options[0] + value = options[0] id = _make_unique_id() order = _apply_default_order(order) return GuiButtonGroupHandle( self._create_gui_input( - initial_value, + value, message=_messages.GuiAddButtonGroupMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, options=tuple(options), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, )._impl, ) @@ -467,21 +475,22 @@ def add_gui_checkbox( Returns: A handle that can be used to interact with the GUI element. """ - assert isinstance(initial_value, bool) + value = initial_value + assert isinstance(value, bool) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddCheckboxMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_text( @@ -506,21 +515,22 @@ def add_gui_text( Returns: A handle that can be used to interact with the GUI element. """ - assert isinstance(initial_value, str) + value = initial_value + assert isinstance(value, str) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddTextMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_number( @@ -552,8 +562,9 @@ def add_gui_number( Returns: A handle that can be used to interact with the GUI element. """ + value = initial_value - assert isinstance(initial_value, (int, float)) + assert isinstance(value, (int, float)) if step is None: # It's ok that `step` is always a float, even if the value is an integer, @@ -561,7 +572,7 @@ def add_gui_number( step = float( # type: ignore onp.min( [ - _compute_step(initial_value), + _compute_step(value), _compute_step(min), _compute_step(max), ] @@ -573,21 +584,21 @@ def add_gui_number( id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value=initial_value, + value, message=_messages.GuiAddNumberMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, min=min, max=max, precision=_compute_precision_digits(step), step=step, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, is_button=False, ) @@ -619,7 +630,8 @@ def add_gui_vector2( Returns: A handle that can be used to interact with the GUI element. """ - initial_value = cast_vector(initial_value, 2) + value = initial_value + value = cast_vector(value, 2) min = cast_vector(min, 2) if min is not None else None max = cast_vector(max, 2) if max is not None else None id = _make_unique_id() @@ -627,7 +639,7 @@ def add_gui_vector2( if step is None: possible_steps: List[float] = [] - possible_steps.extend([_compute_step(x) for x in initial_value]) + possible_steps.extend([_compute_step(x) for x in value]) if min is not None: possible_steps.extend([_compute_step(x) for x in min]) if max is not None: @@ -635,21 +647,21 @@ def add_gui_vector2( step = float(onp.min(possible_steps)) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddVector2Message( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, min=min, max=max, step=step, precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_vector3( @@ -680,7 +692,8 @@ def add_gui_vector3( Returns: A handle that can be used to interact with the GUI element. """ - initial_value = cast_vector(initial_value, 2) + value = initial_value + value = cast_vector(value, 2) min = cast_vector(min, 3) if min is not None else None max = cast_vector(max, 3) if max is not None else None id = _make_unique_id() @@ -688,7 +701,7 @@ def add_gui_vector3( if step is None: possible_steps: List[float] = [] - possible_steps.extend([_compute_step(x) for x in initial_value]) + possible_steps.extend([_compute_step(x) for x in value]) if min is not None: possible_steps.extend([_compute_step(x) for x in min]) if max is not None: @@ -696,21 +709,21 @@ def add_gui_vector3( step = float(onp.min(possible_steps)) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddVector3Message( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, min=min, max=max, step=step, precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) # See add_gui_dropdown for notes on overloads. @@ -764,24 +777,25 @@ def add_gui_dropdown( Returns: A handle that can be used to interact with the GUI element. """ - if initial_value is None: - initial_value = options[0] + value = initial_value + if value is None: + value = options[0] id = _make_unique_id() order = _apply_default_order(order) return GuiDropdownHandle( self._create_gui_input( - initial_value, + value, message=_messages.GuiAddDropdownMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, options=tuple(options), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, )._impl, _impl_options=tuple(options), ) @@ -814,27 +828,28 @@ def add_gui_slider( Returns: A handle that can be used to interact with the GUI element. """ + value: IntOrFloat = initial_value assert max >= min if step > max - min: step = max - min - assert max >= initial_value >= min + assert max >= value >= min # GUI callbacks cast incoming values to match the type of the initial value. If # the min, max, or step is a float, we should cast to a float. - if type(initial_value) is int and ( + if type(value) is int and ( type(min) is float or type(max) is float or type(step) is float ): - initial_value = float(initial_value) # type: ignore + value = float(value) # type: ignore # TODO: as of 6/5/2023, this assert will break something in nerfstudio. (at # least LERF) # - # assert type(min) == type(max) == type(step) == type(initial_value) + # assert type(min) == type(max) == type(step) == type(value) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value=initial_value, + value, message=_messages.GuiAddSliderMessage( order=order, id=id, @@ -844,11 +859,11 @@ def add_gui_slider( min=min, max=max, step=step, - initial_value=initial_value, + value=value, precision=_compute_precision_digits(step), + visible=visible, + disabled=disabled, ), - disabled=disabled, - visible=visible, is_button=False, ) @@ -875,20 +890,21 @@ def add_gui_rgb( A handle that can be used to interact with the GUI element. """ + value = initial_value id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddRgbMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_rgba( @@ -913,28 +929,27 @@ def add_gui_rgba( Returns: A handle that can be used to interact with the GUI element. """ + value = initial_value id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddRgbaMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def _create_gui_input( self, - initial_value: T, + value: T, message: _messages._GuiAddInputBase, - disabled: bool, - visible: bool, is_button: bool = False, ) -> GuiInputHandle[T]: """Private helper for adding a simple GUI element.""" @@ -945,19 +960,19 @@ def _create_gui_input( # Construct handle. handle_state = _GuiHandleState( label=message.label, - typ=type(initial_value), + typ=type(value), gui_api=self, - value=initial_value, + value=value, + initial_value=value, update_timestamp=time.time(), container_id=self._get_container_id(), update_cb=[], is_button=is_button, sync_cb=None, - disabled=False, - visible=True, + disabled=message.disabled, + visible=message.visible, id=message.id, order=message.order, - initial_value=initial_value, hint=message.hint, ) @@ -974,10 +989,4 @@ def sync_other_clients(client_id: ClientId, value: Any) -> None: handle = GuiInputHandle(handle_state) - # Set the disabled/visible fields. These will queue messages under-the-hood. - if disabled: - handle.disabled = disabled - if not visible: - handle.visible = visible - return handle diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index f150ae604..aae704412 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -319,8 +319,10 @@ def options(self, options: Iterable[StringType]) -> None: label=self._impl.label, container_id=self._impl.container_id, hint=self._impl.hint, - initial_value=self._impl.initial_value, + value=self._impl.initial_value, options=self._impl_options, + visible=self._impl.visible, + disabled=self._impl.disabled, ) ) @@ -337,6 +339,7 @@ class GuiTabGroupHandle: _gui_api: GuiApi _container_id: str # Parent. _order: float + _visible: bool @property def order(self) -> float: @@ -374,6 +377,7 @@ def _sync_with_client(self) -> None: tab_labels=tuple(self._labels), tab_icons_base64=tuple(self._icons_base64), tab_container_ids=tuple(tab._id for tab in self._tabs), + visible=self._visible, ) ) @@ -567,6 +571,7 @@ def content(self, content: str) -> None: id=self._id, markdown=_parse_markdown(content, self._image_root), container_id=self._container_id, + visible=self._visible, ) ) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 9cf8cd411..fabdea20a 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -353,6 +353,7 @@ class GuiAddFolderMessage(Message): label: str container_id: str expand_by_default: bool + visible: bool @dataclasses.dataclass @@ -361,6 +362,7 @@ class GuiAddMarkdownMessage(Message): id: str markdown: str container_id: str + visible: bool @dataclasses.dataclass @@ -371,6 +373,7 @@ class GuiAddTabGroupMessage(Message): tab_labels: Tuple[str, ...] tab_icons_base64: Tuple[Union[str, None], ...] tab_container_ids: Tuple[str, ...] + visible: bool @dataclasses.dataclass @@ -382,7 +385,9 @@ class _GuiAddInputBase(Message): label: str container_id: str hint: Optional[str] - initial_value: Any + value: Any + visible: bool + disabled: bool @dataclasses.dataclass @@ -399,9 +404,9 @@ class GuiCloseModalMessage(Message): @dataclasses.dataclass class GuiAddButtonMessage(_GuiAddInputBase): - # All GUI elements currently need an `initial_value` field. + # All GUI elements currently need an `value` field. # This makes our job on the frontend easier. - initial_value: bool + value: bool color: Optional[ Literal[ "dark", @@ -428,13 +433,13 @@ class GuiAddSliderMessage(_GuiAddInputBase): min: float max: float step: Optional[float] - initial_value: float + value: float precision: int @dataclasses.dataclass class GuiAddNumberMessage(_GuiAddInputBase): - initial_value: float + value: float precision: int step: float min: Optional[float] @@ -443,22 +448,22 @@ class GuiAddNumberMessage(_GuiAddInputBase): @dataclasses.dataclass class GuiAddRgbMessage(_GuiAddInputBase): - initial_value: Tuple[int, int, int] + value: Tuple[int, int, int] @dataclasses.dataclass class GuiAddRgbaMessage(_GuiAddInputBase): - initial_value: Tuple[int, int, int, int] + value: Tuple[int, int, int, int] @dataclasses.dataclass class GuiAddCheckboxMessage(_GuiAddInputBase): - initial_value: bool + value: bool @dataclasses.dataclass class GuiAddVector2Message(_GuiAddInputBase): - initial_value: Tuple[float, float] + value: Tuple[float, float] min: Optional[Tuple[float, float]] max: Optional[Tuple[float, float]] step: float @@ -467,7 +472,7 @@ class GuiAddVector2Message(_GuiAddInputBase): @dataclasses.dataclass class GuiAddVector3Message(_GuiAddInputBase): - initial_value: Tuple[float, float, float] + value: Tuple[float, float, float] min: Optional[Tuple[float, float, float]] max: Optional[Tuple[float, float, float]] step: float @@ -476,18 +481,18 @@ class GuiAddVector3Message(_GuiAddInputBase): @dataclasses.dataclass class GuiAddTextMessage(_GuiAddInputBase): - initial_value: str + value: str @dataclasses.dataclass class GuiAddDropdownMessage(_GuiAddInputBase): - initial_value: str + value: str options: Tuple[str, ...] @dataclasses.dataclass class GuiAddButtonGroupMessage(_GuiAddInputBase): - initial_value: str + value: str options: Tuple[str, ...] diff --git a/src/viser/client/src/ControlPanel/Generated.tsx b/src/viser/client/src/ControlPanel/Generated.tsx index 9419eb2a1..6e12c131f 100644 --- a/src/viser/client/src/ControlPanel/Generated.tsx +++ b/src/viser/client/src/ControlPanel/Generated.tsx @@ -1,50 +1,30 @@ -import { - GuiAddFolderMessage, - GuiAddTabGroupMessage, -} from "../WebsocketMessages"; -import { ViewerContext, ViewerContextContents } from "../App"; +import { ViewerContext } from "../App"; import { makeThrottledMessageSender } from "../WebsocketFunctions"; -import { computeRelativeLuminance } from "./GuiState"; -import { - Collapse, - Image, - Paper, - Tabs, - TabsValue, - useMantineTheme, -} from "@mantine/core"; +import { GuiConfig } from "./GuiState"; +import { GuiComponentContext } from "./GuiComponentContext"; import { Box, - Button, - Checkbox, - ColorInput, - Flex, - NumberInput, - Select, - Slider, - Text, - TextInput, - Tooltip, } from "@mantine/core"; import React from "react"; -import Markdown from "../Markdown"; -import { ErrorBoundary } from "react-error-boundary"; -import { useDisclosure } from "@mantine/hooks"; -import { IconChevronDown, IconChevronUp } from "@tabler/icons-react"; - -/** Root of generated inputs. */ -export default function GeneratedGuiContainer({ - // We need to take viewer as input in drei's elements, where contexts break. - containerId, - viewer, - folderDepth, -}: { - containerId: string; - viewer?: ViewerContextContents; - folderDepth?: number; -}) { - if (viewer === undefined) viewer = React.useContext(ViewerContext)!; +import ButtonComponent from "../components/Button"; +import SliderComponent from "../components/Slider"; +import NumberInputComponent from "../components/NumberInput"; +import TextInputComponent from "../components/TextInput"; +import CheckboxComponent from "../components/Checkbox"; +import Vector2Component from "../components/Vector2"; +import Vector3Component from "../components/Vector3"; +import DropdownComponent from "../components/Dropdown"; +import RgbComponent from "../components/Rgb"; +import RgbaComponent from "../components/Rgba"; +import ButtonGroupComponent from "../components/ButtonGroup"; +import MarkdownComponent from "../components/Markdown"; +import TabGroupComponent from "../components/TabGroup"; +import FolderComponent from "../components/Folder"; + + +function GuiContainer({ containerId }: { containerId: string }) { + const viewer = React.useContext(ViewerContext)!; const guiIdSet = viewer.useGui((state) => state.guiIdSetFromContainerId[containerId]) ?? {}; @@ -54,685 +34,84 @@ export default function GeneratedGuiContainer({ const guiOrderFromId = viewer!.useGui((state) => state.guiOrderFromId); if (guiIdSet === undefined) return null; - const guiIdOrderPairArray = guiIdArray.map((id) => ({ + let guiIdOrderPairArray = guiIdArray.map((id) => ({ id: id, order: guiOrderFromId[id], })); + let pb = undefined; + guiIdOrderPairArray = guiIdOrderPairArray.sort((a, b) => a.order - b.order); + const inputProps = viewer.useGui((state) => guiIdOrderPairArray.map(pair => state.guiConfigFromId[pair.id])); + const lastProps = inputProps && inputProps[inputProps.length - 1]; + + // Done to match the old behaviour. Is it still needed? + if (lastProps !== undefined && lastProps.type === "GuiAddFolderMessage") { + pb = "0.125em"; + } const out = ( - - {guiIdOrderPairArray - .sort((a, b) => a.order - b.order) - .map((pair, index) => ( - - ))} + + {inputProps.map((conf) => )} ); return out; } -/** A single generated GUI element. */ -function GeneratedInput({ - id, - viewer, - folderDepth, - last, -}: { - id: string; - viewer?: ViewerContextContents; - folderDepth: number; - last: boolean; -}) { - // Handle GUI input types. - if (viewer === undefined) viewer = React.useContext(ViewerContext)!; - const conf = viewer.useGui((state) => state.guiConfigFromId[id]); - - // Handle nested containers. - if (conf.type == "GuiAddFolderMessage") - return ( - - - - ); - if (conf.type == "GuiAddTabGroupMessage") - return ; - if (conf.type == "GuiAddMarkdownMessage") { - let { visible } = - viewer.useGui((state) => state.guiAttributeFromId[conf.id]) || {}; - visible = visible ?? true; - if (!visible) return <>>; - return ( - - Markdown Failed to Render} - > - {conf.markdown} - - - ); - } - +/** Root of generated inputs. */ +export default function GeneratedGuiContainer({ containerId }: { containerId: string; }) { + const viewer = React.useContext(ViewerContext)!; const messageSender = makeThrottledMessageSender(viewer.websocketRef, 50); - function updateValue(value: any) { - setGuiValue(conf.id, value); - messageSender({ type: "GuiUpdateMessage", id: conf.id, value: value }); + function setValue(id: string, value: any) { + setGuiValue(id, value); + messageSender({ type: "GuiUpdateMessage", id: id, value: value }); } const setGuiValue = viewer.useGui((state) => state.setGuiValue); - const value = - viewer.useGui((state) => state.guiValueFromId[conf.id]) ?? - conf.initial_value; - const theme = useMantineTheme(); + return + + - let { visible, disabled } = - viewer.useGui((state) => state.guiAttributeFromId[conf.id]) || {}; - - visible = visible ?? true; - disabled = disabled ?? false; - - if (!visible) return <>>; - - let inputColor = - computeRelativeLuminance(theme.fn.primaryColor()) > 50.0 - ? theme.colors.gray[9] - : theme.white; +} - let labeled = true; - let input = null; +/** A single generated GUI element. */ +function GeneratedInput(conf: GuiConfig) { switch (conf.type) { + case "GuiAddFolderMessage": + return ; + case "GuiAddTabGroupMessage": + return ; + case "GuiAddMarkdownMessage": + return ; case "GuiAddButtonMessage": - labeled = false; - if (conf.color !== null) { - inputColor = - computeRelativeLuminance( - theme.colors[conf.color][theme.fn.primaryShade()], - ) > 50.0 - ? theme.colors.gray[9] - : theme.white; - } - - input = ( - - messageSender({ - type: "GuiUpdateMessage", - id: conf.id, - value: true, - }) - } - style={{ height: "2.125em" }} - styles={{ inner: { color: inputColor + " !important" } }} - disabled={disabled} - size="sm" - leftIcon={ - conf.icon_base64 === null ? undefined : ( - - ) - } - > - {conf.label} - - ); - break; + return ; case "GuiAddSliderMessage": - input = ( - - - ({ - thumb: { - background: theme.fn.primaryColor(), - borderRadius: "0.1em", - height: "0.75em", - width: "0.625em", - }, - })} - pt="0.2em" - showLabelOnHover={false} - min={conf.min} - max={conf.max} - step={conf.step ?? undefined} - precision={conf.precision} - value={value} - onChange={updateValue} - marks={[{ value: conf.min }, { value: conf.max }]} - disabled={disabled} - /> - - {parseInt(conf.min.toFixed(6))} - {parseInt(conf.max.toFixed(6))} - - - { - // Ignore empty values. - newValue !== "" && updateValue(newValue); - }} - size="xs" - min={conf.min} - max={conf.max} - hideControls - step={conf.step ?? undefined} - precision={conf.precision} - sx={{ width: "3rem" }} - styles={{ - input: { - padding: "0.375em", - letterSpacing: "-0.5px", - minHeight: "1.875em", - height: "1.875em", - }, - }} - ml="xs" - /> - - ); - break; + return ; case "GuiAddNumberMessage": - input = ( - { - // Ignore empty values. - newValue !== "" && updateValue(newValue); - }} - styles={{ - input: { - minHeight: "1.625rem", - height: "1.625rem", - }, - }} - disabled={disabled} - stepHoldDelay={500} - stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} - /> - ); - break; + return ; case "GuiAddTextMessage": - input = ( - { - updateValue(value.target.value); - }} - styles={{ - input: { - minHeight: "1.625rem", - height: "1.625rem", - padding: "0 0.5em", - }, - }} - disabled={disabled} - /> - ); - break; + return ; case "GuiAddCheckboxMessage": - input = ( - { - updateValue(value.target.checked); - }} - disabled={disabled} - styles={{ - icon: { - color: inputColor + " !important", - }, - }} - /> - ); - break; + return ; case "GuiAddVector2Message": - input = ( - - ); - break; + return ; case "GuiAddVector3Message": - input = ( - - ); - break; + return ; case "GuiAddDropdownMessage": - input = ( - modal zIndex. - // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. - zIndex={1000} - withinPortal - /> - ); - break; + return ; case "GuiAddRgbMessage": - input = ( - updateValue(hexToRgb(v))} - format="hex" - // zIndex of dropdown should be >modal zIndex. - // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. - dropdownZIndex={1000} - withinPortal - styles={{ - input: { height: "1.625rem", minHeight: "1.625rem" }, - icon: { transform: "scale(0.8)" }, - }} - /> - ); - break; + return ; case "GuiAddRgbaMessage": - input = ( - updateValue(hexToRgba(v))} - format="hexa" - // zIndex of dropdown should be >modal zIndex. - // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. - dropdownZIndex={1000} - withinPortal - styles={{ input: { height: "1.625rem", minHeight: "1.625rem" } }} - /> - ); - break; + return ; case "GuiAddButtonGroupMessage": - input = ( - - {conf.options.map((option, index) => ( - - messageSender({ - type: "GuiUpdateMessage", - id: conf.id, - value: option, - }) - } - style={{ flexGrow: 1, width: 0 }} - disabled={disabled} - compact - size="xs" - variant="outline" - > - {option} - - ))} - - ); + return ; + default: + assertNeverType(conf); } - - if (conf.hint !== null) - input = // We need to add for inputs that we can't assign refs to. - ( - - - {input} - - - ); - - if (labeled) - input = ( - - ); - - return ( - - {input} - - ); } -function GeneratedFolder({ - conf, - folderDepth, - viewer, -}: { - conf: GuiAddFolderMessage; - folderDepth: number; - viewer: ViewerContextContents; -}) { - const [opened, { toggle }] = useDisclosure(conf.expand_by_default); - const guiIdSet = viewer.useGui( - (state) => state.guiIdSetFromContainerId[conf.id], - ); - const isEmpty = guiIdSet === undefined || Object.keys(guiIdSet).length === 0; - - const ToggleIcon = opened ? IconChevronUp : IconChevronDown; - return ( - - - {conf.label} - - - - - - - - - - ); -} - -function GeneratedTabGroup({ conf }: { conf: GuiAddTabGroupMessage }) { - const [tabState, setTabState] = React.useState("0"); - const icons = conf.tab_icons_base64; - - return ( - - - {conf.tab_labels.map((label, index) => ( - ({ - filter: - theme.colorScheme == "dark" ? "invert(1)" : undefined, - })} - src={"data:image/svg+xml;base64," + icons[index]} - /> - ) - } - > - {label} - - ))} - - {conf.tab_container_ids.map((containerId, index) => ( - - - - ))} - - ); -} - -function VectorInput( - props: - | { - id: string; - n: 2; - value: [number, number]; - min: [number, number] | null; - max: [number, number] | null; - step: number; - precision: number; - onChange: (value: number[]) => void; - disabled: boolean; - } - | { - id: string; - n: 3; - value: [number, number, number]; - min: [number, number, number] | null; - max: [number, number, number] | null; - step: number; - precision: number; - onChange: (value: number[]) => void; - disabled: boolean; - }, -) { - return ( - - {[...Array(props.n).keys()].map((i) => ( - { - const updated = [...props.value]; - updated[i] = v === "" ? 0.0 : v; - props.onChange(updated); - }} - size="xs" - styles={{ - root: { flexGrow: 1, width: 0 }, - input: { - paddingLeft: "0.5em", - paddingRight: "1.75em", - textAlign: "right", - minHeight: "1.875em", - height: "1.875em", - }, - rightSection: { width: "1.2em" }, - control: { - width: "1.1em", - }, - }} - precision={props.precision} - step={props.step} - min={props.min === null ? undefined : props.min[i]} - max={props.max === null ? undefined : props.max[i]} - stepHoldDelay={500} - stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} - disabled={props.disabled} - /> - ))} - - ); -} - -/** GUI input with a label horizontally placed to the left of it. */ -function LabeledInput(props: { - id: string; - label: string; - input: React.ReactNode; - folderDepth: number; -}) { - return ( - - - - {props.label} - - - {props.input} - - ); -} - -// Color conversion helpers. - -function rgbToHex([r, g, b]: [number, number, number]): string { - const hexR = r.toString(16).padStart(2, "0"); - const hexG = g.toString(16).padStart(2, "0"); - const hexB = b.toString(16).padStart(2, "0"); - return `#${hexR}${hexG}${hexB}`; -} - -function hexToRgb(hexColor: string): [number, number, number] { - const hex = hexColor.slice(1); // Remove the # in #ffffff. - const r = parseInt(hex.substring(0, 2), 16); - const g = parseInt(hex.substring(2, 4), 16); - const b = parseInt(hex.substring(4, 6), 16); - return [r, g, b]; -} -function rgbaToHex([r, g, b, a]: [number, number, number, number]): string { - const hexR = r.toString(16).padStart(2, "0"); - const hexG = g.toString(16).padStart(2, "0"); - const hexB = b.toString(16).padStart(2, "0"); - const hexA = a.toString(16).padStart(2, "0"); - return `#${hexR}${hexG}${hexB}${hexA}`; -} - -function hexToRgba(hexColor: string): [number, number, number, number] { - const hex = hexColor.slice(1); // Remove the # in #ffffff. - const r = parseInt(hex.substring(0, 2), 16); - const g = parseInt(hex.substring(2, 4), 16); - const b = parseInt(hex.substring(4, 6), 16); - const a = parseInt(hex.substring(6, 8), 16); - return [r, g, b, a]; -} +function assertNeverType(x: never): never { + throw new Error("Unexpected object: " + (x as any).type); +} \ No newline at end of file diff --git a/src/viser/client/src/ControlPanel/GuiComponentContext.tsx b/src/viser/client/src/ControlPanel/GuiComponentContext.tsx new file mode 100644 index 000000000..f566efc1c --- /dev/null +++ b/src/viser/client/src/ControlPanel/GuiComponentContext.tsx @@ -0,0 +1,18 @@ +import * as React from "react"; +import * as Messages from "../WebsocketMessages"; + +interface GuiComponentContext { + folderDepth: number, + setValue: (id: string, value: any) => void, + messageSender: (message: Messages.Message) => void, + GuiContainer: React.FC<{ containerId: string }>, +} + +export const GuiComponentContext = React.createContext({ + folderDepth: 0, + setValue: () => undefined, + messageSender: () => undefined, + GuiContainer: () => { + throw new Error("GuiComponentContext not initialized"); + }, +}); diff --git a/src/viser/client/src/ControlPanel/GuiState.tsx b/src/viser/client/src/ControlPanel/GuiState.tsx index e70dc4e18..1ab9bac2f 100644 --- a/src/viser/client/src/ControlPanel/GuiState.tsx +++ b/src/viser/client/src/ControlPanel/GuiState.tsx @@ -40,10 +40,6 @@ interface GuiState { modals: Messages.GuiModalMessage[]; guiOrderFromId: { [id: string]: number }; guiConfigFromId: { [id: string]: GuiConfig }; - guiValueFromId: { [id: string]: any }; - guiAttributeFromId: { - [id: string]: { visible?: boolean; disabled?: boolean } | undefined; - }; } interface GuiActions { @@ -79,8 +75,6 @@ const cleanGuiState: GuiState = { modals: [], guiOrderFromId: {}, guiConfigFromId: {}, - guiValueFromId: {}, - guiAttributeFromId: {}, }; export function computeRelativeLuminance(color: string) { @@ -130,21 +124,18 @@ export function useGuiState(initialServer: string) { }), setGuiValue: (id, value) => set((state) => { - state.guiValueFromId[id] = value; + const config = state.guiConfigFromId[id] as any; + state.guiConfigFromId[id] = {...config, value} as GuiConfig; }), setGuiVisible: (id, visible) => set((state) => { - state.guiAttributeFromId[id] = { - ...state.guiAttributeFromId[id], - visible: visible, - }; + const config = state.guiConfigFromId[id] as any; + state.guiConfigFromId[id] = {...config, visible} as GuiConfig; }), setGuiDisabled: (id, disabled) => set((state) => { - state.guiAttributeFromId[id] = { - ...state.guiAttributeFromId[id], - disabled: disabled, - }; + const config = state.guiConfigFromId[id] as any; + state.guiConfigFromId[id] = {...config, disabled} as GuiConfig; }), removeGui: (id) => set((state) => { @@ -153,8 +144,6 @@ export function useGuiState(initialServer: string) { delete state.guiIdSetFromContainerId[guiConfig.container_id]![id]; delete state.guiOrderFromId[id]; delete state.guiConfigFromId[id]; - delete state.guiValueFromId[id]; - delete state.guiAttributeFromId[id]; }), resetGui: () => set((state) => { @@ -162,8 +151,6 @@ export function useGuiState(initialServer: string) { state.guiIdSetFromContainerId = {}; state.guiOrderFromId = {}; state.guiConfigFromId = {}; - state.guiValueFromId = {}; - state.guiAttributeFromId = {}; }), })), ), diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 65a292da0..0d60776c8 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -666,10 +666,9 @@ function useMessageHandler() { evt.stopPropagation(); }} > - + + + diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index e59b51859..3f5409416 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -312,7 +312,7 @@ export interface SceneNodeClickMessage { export interface ResetSceneMessage { type: "ResetSceneMessage"; } -/** GuiAddFolderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', expand_by_default: 'bool') +/** GuiAddFolderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', expand_by_default: 'bool', visible: 'bool') * * (automatically generated) */ @@ -323,8 +323,9 @@ export interface GuiAddFolderMessage { label: string; container_id: string; expand_by_default: boolean; + visible: boolean; } -/** GuiAddMarkdownMessage(order: 'float', id: 'str', markdown: 'str', container_id: 'str') +/** GuiAddMarkdownMessage(order: 'float', id: 'str', markdown: 'str', container_id: 'str', visible: 'bool') * * (automatically generated) */ @@ -334,8 +335,9 @@ export interface GuiAddMarkdownMessage { id: string; markdown: string; container_id: string; + visible: boolean; } -/** GuiAddTabGroupMessage(order: 'float', id: 'str', container_id: 'str', tab_labels: 'Tuple[str, ...]', tab_icons_base64: 'Tuple[Union[str, None], ...]', tab_container_ids: 'Tuple[str, ...]') +/** GuiAddTabGroupMessage(order: 'float', id: 'str', container_id: 'str', tab_labels: 'Tuple[str, ...]', tab_icons_base64: 'Tuple[Union[str, None], ...]', tab_container_ids: 'Tuple[str, ...]', visible: 'bool') * * (automatically generated) */ @@ -347,6 +349,7 @@ export interface GuiAddTabGroupMessage { tab_labels: string[]; tab_icons_base64: (string | null)[]; tab_container_ids: string[]; + visible: boolean; } /** Base message type containing fields commonly used by GUI inputs. * @@ -359,9 +362,11 @@ export interface _GuiAddInputBase { label: string; container_id: string; hint: string | null; - initial_value: any; + value: any; + visible: boolean; + disabled: boolean; } -/** GuiAddButtonMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'bool', color: "Optional[Literal['dark', 'gray', 'red', 'pink', 'grape', 'violet', 'indigo', 'blue', 'cyan', 'green', 'lime', 'yellow', 'orange', 'teal']]", icon_base64: 'Optional[str]') +/** GuiAddButtonMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'bool', visible: 'bool', disabled: 'bool', color: "Optional[Literal['dark', 'gray', 'red', 'pink', 'grape', 'violet', 'indigo', 'blue', 'cyan', 'green', 'lime', 'yellow', 'orange', 'teal']]", icon_base64: 'Optional[str]') * * (automatically generated) */ @@ -372,7 +377,9 @@ export interface GuiAddButtonMessage { label: string; container_id: string; hint: string | null; - initial_value: boolean; + value: boolean; + visible: boolean; + disabled: boolean; color: | "dark" | "gray" @@ -391,7 +398,7 @@ export interface GuiAddButtonMessage { | null; icon_base64: string | null; } -/** GuiAddSliderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'float', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int') +/** GuiAddSliderMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'float', visible: 'bool', disabled: 'bool', min: 'float', max: 'float', step: 'Optional[float]', precision: 'int') * * (automatically generated) */ @@ -402,13 +409,15 @@ export interface GuiAddSliderMessage { label: string; container_id: string; hint: string | null; - initial_value: number; + value: number; + visible: boolean; + disabled: boolean; min: number; max: number; step: number | null; precision: number; } -/** GuiAddNumberMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'float', precision: 'int', step: 'float', min: 'Optional[float]', max: 'Optional[float]') +/** GuiAddNumberMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'float', visible: 'bool', disabled: 'bool', precision: 'int', step: 'float', min: 'Optional[float]', max: 'Optional[float]') * * (automatically generated) */ @@ -419,13 +428,15 @@ export interface GuiAddNumberMessage { label: string; container_id: string; hint: string | null; - initial_value: number; + value: number; + visible: boolean; + disabled: boolean; precision: number; step: number; min: number | null; max: number | null; } -/** GuiAddRgbMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'Tuple[int, int, int]') +/** GuiAddRgbMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[int, int, int]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ @@ -436,9 +447,11 @@ export interface GuiAddRgbMessage { label: string; container_id: string; hint: string | null; - initial_value: [number, number, number]; + value: [number, number, number]; + visible: boolean; + disabled: boolean; } -/** GuiAddRgbaMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'Tuple[int, int, int, int]') +/** GuiAddRgbaMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[int, int, int, int]', visible: 'bool', disabled: 'bool') * * (automatically generated) */ @@ -449,9 +462,11 @@ export interface GuiAddRgbaMessage { label: string; container_id: string; hint: string | null; - initial_value: [number, number, number, number]; + value: [number, number, number, number]; + visible: boolean; + disabled: boolean; } -/** GuiAddCheckboxMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'bool') +/** GuiAddCheckboxMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'bool', visible: 'bool', disabled: 'bool') * * (automatically generated) */ @@ -462,9 +477,11 @@ export interface GuiAddCheckboxMessage { label: string; container_id: string; hint: string | null; - initial_value: boolean; + value: boolean; + visible: boolean; + disabled: boolean; } -/** GuiAddVector2Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'Tuple[float, float]', min: 'Optional[Tuple[float, float]]', max: 'Optional[Tuple[float, float]]', step: 'float', precision: 'int') +/** GuiAddVector2Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[float, float]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float]]', max: 'Optional[Tuple[float, float]]', step: 'float', precision: 'int') * * (automatically generated) */ @@ -475,13 +492,15 @@ export interface GuiAddVector2Message { label: string; container_id: string; hint: string | null; - initial_value: [number, number]; + value: [number, number]; + visible: boolean; + disabled: boolean; min: [number, number] | null; max: [number, number] | null; step: number; precision: number; } -/** GuiAddVector3Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'Tuple[float, float, float]', min: 'Optional[Tuple[float, float, float]]', max: 'Optional[Tuple[float, float, float]]', step: 'float', precision: 'int') +/** GuiAddVector3Message(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'Tuple[float, float, float]', visible: 'bool', disabled: 'bool', min: 'Optional[Tuple[float, float, float]]', max: 'Optional[Tuple[float, float, float]]', step: 'float', precision: 'int') * * (automatically generated) */ @@ -492,13 +511,15 @@ export interface GuiAddVector3Message { label: string; container_id: string; hint: string | null; - initial_value: [number, number, number]; + value: [number, number, number]; + visible: boolean; + disabled: boolean; min: [number, number, number] | null; max: [number, number, number] | null; step: number; precision: number; } -/** GuiAddTextMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'str') +/** GuiAddTextMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool') * * (automatically generated) */ @@ -509,9 +530,11 @@ export interface GuiAddTextMessage { label: string; container_id: string; hint: string | null; - initial_value: string; + value: string; + visible: boolean; + disabled: boolean; } -/** GuiAddDropdownMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'str', options: 'Tuple[str, ...]') +/** GuiAddDropdownMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') * * (automatically generated) */ @@ -522,10 +545,12 @@ export interface GuiAddDropdownMessage { label: string; container_id: string; hint: string | null; - initial_value: string; + value: string; + visible: boolean; + disabled: boolean; options: string[]; } -/** GuiAddButtonGroupMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', initial_value: 'str', options: 'Tuple[str, ...]') +/** GuiAddButtonGroupMessage(order: 'float', id: 'str', label: 'str', container_id: 'str', hint: 'Optional[str]', value: 'str', visible: 'bool', disabled: 'bool', options: 'Tuple[str, ...]') * * (automatically generated) */ @@ -536,7 +561,9 @@ export interface GuiAddButtonGroupMessage { label: string; container_id: string; hint: string | null; - initial_value: string; + value: string; + visible: boolean; + disabled: boolean; options: string[]; } /** GuiModalMessage(order: 'float', id: 'str', title: 'str') diff --git a/src/viser/client/src/components/Button.tsx b/src/viser/client/src/components/Button.tsx new file mode 100644 index 000000000..851582a0b --- /dev/null +++ b/src/viser/client/src/components/Button.tsx @@ -0,0 +1,69 @@ +import { + GuiAddButtonMessage, +} from "../WebsocketMessages"; +import { computeRelativeLuminance } from "../ControlPanel/GuiState"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { + Box, + Image, + useMantineTheme, +} from "@mantine/core"; + +import { Button } from "@mantine/core"; +import React from "react"; + + +export default function ButtonComponent({ id, visible, disabled, label, ...otherProps }: GuiAddButtonMessage) { + const { messageSender } = React.useContext(GuiComponentContext)!; + const theme = useMantineTheme(); + const { color, icon_base64 } = otherProps; + if (!(visible ?? true)) return <>>; + + const inputColor = + computeRelativeLuminance(theme.fn.primaryColor()) > 50.0 + ? theme.colors.gray[9] + : theme.white; + return ( + + + messageSender({ + type: "GuiUpdateMessage", + id: id, + value: true, + }) + } + style={{ height: "2.125em" }} + styles={{ inner: { color: inputColor + " !important" } }} + disabled={disabled ?? false} + size="sm" + leftIcon={ + icon_base64 === null ? undefined : ( + + ) + } + > + {label} + + + ); +} \ No newline at end of file diff --git a/src/viser/client/src/components/ButtonGroup.tsx b/src/viser/client/src/components/ButtonGroup.tsx new file mode 100644 index 000000000..ce747ef2f --- /dev/null +++ b/src/viser/client/src/components/ButtonGroup.tsx @@ -0,0 +1,40 @@ +import * as React from "react"; +import { Button, Flex } from "@mantine/core"; +import { ViserInputComponent } from "./common"; +import { GuiAddButtonGroupMessage } from "../WebsocketMessages"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; + +export default function ButtonGroupComponent({ + id, + hint, + label, + visible, + disabled, + options, +}: GuiAddButtonGroupMessage) { + const { messageSender } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return + + {options.map((option, index) => ( + + messageSender({ + type: "GuiUpdateMessage", + id: id, + value: option, + }) + } + style={{ flexGrow: 1, width: 0 }} + disabled={disabled} + compact + size="xs" + variant="outline" + > + {option} + + ))} + + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Checkbox.tsx b/src/viser/client/src/components/Checkbox.tsx new file mode 100644 index 000000000..5229bc4cc --- /dev/null +++ b/src/viser/client/src/components/Checkbox.tsx @@ -0,0 +1,49 @@ +import * as React from "react"; +import { ViserInputComponent } from "./common"; +import { computeRelativeLuminance } from "../ControlPanel/GuiState"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { GuiAddCheckboxMessage } from "../WebsocketMessages"; +import { Box, Checkbox, Tooltip, useMantineTheme } from "@mantine/core"; + +export default function CheckboxComponent({ id, disabled, visible, hint, label, value }: GuiAddCheckboxMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + const theme = useMantineTheme(); + const inputColor = + computeRelativeLuminance(theme.fn.primaryColor()) > 50.0 + ? theme.colors.gray[9] + : theme.white; + let input = { + setValue(id, value.target.checked); + }} + disabled={disabled} + styles={{ + icon: { + color: inputColor + " !important", + }, + }} + /> + if (hint !== null && hint !== undefined) { + // For checkboxes, we want to make sure that the wrapper + // doesn't expand to the full width of the parent. This will + // de-center the tooltip. + input = + + {input} + + + } + return {input}; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Dropdown.tsx b/src/viser/client/src/components/Dropdown.tsx new file mode 100644 index 000000000..121730b4c --- /dev/null +++ b/src/viser/client/src/components/Dropdown.tsx @@ -0,0 +1,36 @@ +import * as React from "react"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { ViserInputComponent } from "./common"; +import { GuiAddDropdownMessage } from "../WebsocketMessages"; +import { Select } from "@mantine/core"; + + +export default function DropdownComponent({ id, hint, label, value, disabled, visible, options }: GuiAddDropdownMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return + setValue(id, value)} + disabled={disabled} + searchable + maxDropdownHeight={400} + size="xs" + styles={{ + input: { + padding: "0.5em", + letterSpacing: "-0.5px", + minHeight: "1.625rem", + height: "1.625rem", + }, + }} + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + zIndex={1000} + withinPortal + /> + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Folder.tsx b/src/viser/client/src/components/Folder.tsx new file mode 100644 index 000000000..3620f60ac --- /dev/null +++ b/src/viser/client/src/components/Folder.tsx @@ -0,0 +1,77 @@ +import * as React from "react"; +import { useDisclosure } from "@mantine/hooks"; +import { GuiAddFolderMessage } from "../WebsocketMessages"; +import { IconChevronDown, IconChevronUp } from "@tabler/icons-react"; +import { Box, Collapse, Paper } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { ViewerContext } from "../App"; + + +export default function FolderComponent({ + id, + label, + visible, + expand_by_default, +}: GuiAddFolderMessage) { + const viewer = React.useContext(ViewerContext)!; + const [opened, { toggle }] = useDisclosure(expand_by_default); + const guiIdSet = viewer.useGui( + (state) => state.guiIdSetFromContainerId[id], + ); + const guiContext = React.useContext(GuiComponentContext)!; + const isEmpty = guiIdSet === undefined || Object.keys(guiIdSet).length === 0; + + const ToggleIcon = opened ? IconChevronUp : IconChevronDown; + if (!visible) return <>>; + return ( + + + {label} + + + + + + + + + + + + ); +} \ No newline at end of file diff --git a/src/viser/client/src/components/Markdown.tsx b/src/viser/client/src/components/Markdown.tsx new file mode 100644 index 000000000..da1e79682 --- /dev/null +++ b/src/viser/client/src/components/Markdown.tsx @@ -0,0 +1,18 @@ +import { Box, Text } from "@mantine/core"; +import Markdown from "../Markdown"; +import { ErrorBoundary } from "react-error-boundary"; +import { GuiAddMarkdownMessage } from "../WebsocketMessages"; + + +export default function MarkdownComponent({ visible, markdown }: GuiAddMarkdownMessage) { + if (!visible) return <>>; + return ( + + Markdown Failed to Render} + > + {markdown} + + + ); +} \ No newline at end of file diff --git a/src/viser/client/src/components/NumberInput.tsx b/src/viser/client/src/components/NumberInput.tsx new file mode 100644 index 000000000..5056601b9 --- /dev/null +++ b/src/viser/client/src/components/NumberInput.tsx @@ -0,0 +1,36 @@ +import * as React from "react"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { GuiAddNumberMessage } from "../WebsocketMessages"; +import { ViserInputComponent } from "./common"; +import { NumberInput } from "@mantine/core"; + + +export default function NumberInputComponent({ visible, id, label, hint, value, disabled, ...otherProps }: GuiAddNumberMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + const { precision, min, max, step } = otherProps; + if (!visible) return <>>; + return + { + // Ignore empty values. + newValue !== "" && setValue(id, newValue); + }} + styles={{ + input: { + minHeight: "1.625rem", + height: "1.625rem", + }, + }} + disabled={disabled} + stepHoldDelay={500} + stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} + /> + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Rgb.tsx b/src/viser/client/src/components/Rgb.tsx new file mode 100644 index 000000000..f5c5fbbe6 --- /dev/null +++ b/src/viser/client/src/components/Rgb.tsx @@ -0,0 +1,28 @@ +import * as React from "react"; +import { ColorInput } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { rgbToHex, hexToRgb } from "./utils"; +import { ViserInputComponent } from "./common"; +import { GuiAddRgbMessage } from "../WebsocketMessages"; + +export default function RgbComponent({ id, label, hint, value, disabled, visible }: GuiAddRgbMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return + setValue(id, hexToRgb(v))} + format="hex" + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + dropdownZIndex={1000} + withinPortal + styles={{ + input: { height: "1.625rem", minHeight: "1.625rem" }, + icon: { transform: "scale(0.8)" }, + }} + /> + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Rgba.tsx b/src/viser/client/src/components/Rgba.tsx new file mode 100644 index 000000000..755a5e51c --- /dev/null +++ b/src/viser/client/src/components/Rgba.tsx @@ -0,0 +1,27 @@ +import * as React from "react"; +import { ColorInput } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { rgbaToHex, hexToRgba } from "./utils"; +import { ViserInputComponent } from "./common"; +import { GuiAddRgbaMessage } from "../WebsocketMessages"; + +export default function RgbaComponent({ id, label, hint, value, disabled, visible }: GuiAddRgbaMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return + setValue(id, hexToRgba(v))} + format="hexa" + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + dropdownZIndex={1000} + withinPortal + styles={{ + input: { height: "1.625rem", minHeight: "1.625rem" }, + }} + /> + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Slider.tsx b/src/viser/client/src/components/Slider.tsx new file mode 100644 index 000000000..02bf300c3 --- /dev/null +++ b/src/viser/client/src/components/Slider.tsx @@ -0,0 +1,91 @@ +import React from "react"; +import { GuiAddSliderMessage } from "../WebsocketMessages"; +import { + Slider, + Box, + Flex, + Text, + NumberInput, +} from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { ViserInputComponent } from "./common"; + + + +export default function SliderComponent({ id, label, hint, visible, disabled, value, ...otherProps }: GuiAddSliderMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + const updateValue = (value: number) => setValue(id, value); + const { min, max, precision, step } = otherProps; + let input = ( + + + ({ + thumb: { + background: theme.fn.primaryColor(), + borderRadius: "0.1em", + height: "0.75em", + width: "0.625em", + }, + })} + pt="0.2em" + showLabelOnHover={false} + min={min} + max={max} + step={step ?? undefined} + precision={precision} + value={value} + onChange={updateValue} + marks={[{ value: min }, { value: max }]} + disabled={disabled} + /> + + {parseInt(min.toFixed(6))} + {parseInt(max.toFixed(6))} + + + { + // Ignore empty values. + newValue !== "" && updateValue(newValue); + }} + size="xs" + min={min} + max={max} + hideControls + step={step ?? undefined} + precision={precision} + sx={{ width: "3rem" }} + styles={{ + input: { + padding: "0.375em", + letterSpacing: "-0.5px", + minHeight: "1.875em", + height: "1.875em", + }, + }} + ml="xs" + /> + + ); + + const containerProps = {}; + // if (marks?.some(x => x.label)) + // containerProps = { ...containerProps, "mb": "md" }; + + input = {input} + return {input}; +} \ No newline at end of file diff --git a/src/viser/client/src/components/TabGroup.tsx b/src/viser/client/src/components/TabGroup.tsx new file mode 100644 index 000000000..11f21a3f5 --- /dev/null +++ b/src/viser/client/src/components/TabGroup.tsx @@ -0,0 +1,56 @@ +import * as React from "react"; +import { GuiAddTabGroupMessage } from "../WebsocketMessages"; +import { Tabs, TabsValue } from "@mantine/core"; +import { Image } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; + + +export default function TabGroupComponent({ + tab_labels, + tab_icons_base64, + tab_container_ids, + visible, +}: GuiAddTabGroupMessage) { + const [tabState, setTabState] = React.useState("0"); + const icons = tab_icons_base64; + const { GuiContainer } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return ( + + + {tab_labels.map((label, index) => ( + ({ + filter: + theme.colorScheme == "dark" ? "invert(1)" : undefined, + })} + src={"data:image/svg+xml;base64," + icons[index]} + /> + ) + } + > + {label} + + ))} + + {tab_container_ids.map((containerId, index) => ( + + + + ))} + + ); +} \ No newline at end of file diff --git a/src/viser/client/src/components/TextInput.tsx b/src/viser/client/src/components/TextInput.tsx new file mode 100644 index 000000000..671354162 --- /dev/null +++ b/src/viser/client/src/components/TextInput.tsx @@ -0,0 +1,28 @@ +import * as React from "react"; +import { TextInput } from "@mantine/core"; +import { ViserInputComponent } from "./common"; +import { GuiAddTextMessage } from "../WebsocketMessages"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; + +export default function TextInputComponent(props: GuiAddTextMessage) { + const { id, hint, label, value, disabled, visible } = props; + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return + { + setValue(id, value.target.value); + }} + styles={{ + input: { + minHeight: "1.625rem", + height: "1.625rem", + padding: "0 0.5em", + }, + }} + disabled={disabled} + /> + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Vector2.tsx b/src/viser/client/src/components/Vector2.tsx new file mode 100644 index 000000000..1c98276c6 --- /dev/null +++ b/src/viser/client/src/components/Vector2.tsx @@ -0,0 +1,23 @@ +import * as React from "react"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { GuiAddVector2Message } from "../WebsocketMessages"; +import { VectorInput, ViserInputComponent } from "./common"; + +export default function Vector2Component({ id, hint, label, visible, disabled, value, ...otherProps }: GuiAddVector2Message) { + const { min, max, step, precision } = otherProps; + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return + setValue(id, value)} + min={min} + max={max} + step={step} + precision={precision} + disabled={disabled} + /> + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/Vector3.tsx b/src/viser/client/src/components/Vector3.tsx new file mode 100644 index 000000000..42cb569df --- /dev/null +++ b/src/viser/client/src/components/Vector3.tsx @@ -0,0 +1,23 @@ +import * as React from "react"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { GuiAddVector3Message } from "../WebsocketMessages"; +import { VectorInput, ViserInputComponent } from "./common"; + +export default function Vector3Component({ id, hint, label, visible, disabled, value, ...otherProps }: GuiAddVector3Message) { + const { min, max, step, precision } = otherProps; + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + return + setValue(id, value)} + min={min} + max={max} + step={step} + precision={precision} + disabled={disabled} + /> + ; +} \ No newline at end of file diff --git a/src/viser/client/src/components/common.tsx b/src/viser/client/src/components/common.tsx new file mode 100644 index 000000000..5a9bfac6b --- /dev/null +++ b/src/viser/client/src/components/common.tsx @@ -0,0 +1,147 @@ +import * as React from 'react'; +import { + Box, + Flex, + Text, + NumberInput, + Tooltip, +} from '@mantine/core'; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; + +export function ViserInputComponent({ id, label, hint, children }: { id: string, children: React.ReactNode, label?: string, hint?: string | null }) { + const { folderDepth } = React.useContext(GuiComponentContext)!; + if (hint !== undefined && hint !== null) { + children = // We need to add for inputs that we can't assign refs to. + ( + + {children} + + ); + } + + if (label !== undefined) + children = ( + + ); + + return ( + + {children} + + ); +} + + +/** GUI input with a label horizontally placed to the left of it. */ +function LabeledInput(props: { + id: string; + label: string; + input: React.ReactNode; + folderDepth: number; +}) { + return ( + + + + {props.label} + + + {props.input} + + ); +} + + +export function VectorInput( + props: + | { + id: string; + n: 2; + value: [number, number]; + min: [number, number] | null; + max: [number, number] | null; + step: number; + precision: number; + onChange: (value: number[]) => void; + disabled: boolean; + } + | { + id: string; + n: 3; + value: [number, number, number]; + min: [number, number, number] | null; + max: [number, number, number] | null; + step: number; + precision: number; + onChange: (value: number[]) => void; + disabled: boolean; + }, +) { + return ( + + {[...Array(props.n).keys()].map((i) => ( + { + const updated = [...props.value]; + updated[i] = v === "" ? 0.0 : v; + props.onChange(updated); + }} + size="xs" + styles={{ + root: { flexGrow: 1, width: 0 }, + input: { + paddingLeft: "0.5em", + paddingRight: "1.75em", + textAlign: "right", + minHeight: "1.875em", + height: "1.875em", + }, + rightSection: { width: "1.2em" }, + control: { + width: "1.1em", + }, + }} + precision={props.precision} + step={props.step} + min={props.min === null ? undefined : props.min[i]} + max={props.max === null ? undefined : props.max[i]} + stepHoldDelay={500} + stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} + disabled={props.disabled} + /> + ))} + + ); +} diff --git a/src/viser/client/src/components/utils.tsx b/src/viser/client/src/components/utils.tsx new file mode 100644 index 000000000..0ec064548 --- /dev/null +++ b/src/viser/client/src/components/utils.tsx @@ -0,0 +1,32 @@ +// Color conversion helpers. + +export function rgbToHex([r, g, b]: [number, number, number]): string { + const hexR = r.toString(16).padStart(2, "0"); + const hexG = g.toString(16).padStart(2, "0"); + const hexB = b.toString(16).padStart(2, "0"); + return `#${hexR}${hexG}${hexB}`; +} + +export function hexToRgb(hexColor: string): [number, number, number] { + const hex = hexColor.slice(1); // Remove the # in #ffffff. + const r = parseInt(hex.substring(0, 2), 16); + const g = parseInt(hex.substring(2, 4), 16); + const b = parseInt(hex.substring(4, 6), 16); + return [r, g, b]; +} +export function rgbaToHex([r, g, b, a]: [number, number, number, number]): string { + const hexR = r.toString(16).padStart(2, "0"); + const hexG = g.toString(16).padStart(2, "0"); + const hexB = b.toString(16).padStart(2, "0"); + const hexA = a.toString(16).padStart(2, "0"); + return `#${hexR}${hexG}${hexB}${hexA}`; +} + +export function hexToRgba(hexColor: string): [number, number, number, number] { + const hex = hexColor.slice(1); // Remove the # in #ffffff. + const r = parseInt(hex.substring(0, 2), 16); + const g = parseInt(hex.substring(2, 4), 16); + const b = parseInt(hex.substring(4, 6), 16); + const a = parseInt(hex.substring(6, 8), 16); + return [r, g, b, a]; +} \ No newline at end of file From 1ac0ebb37cecebb4d5b53c7fbee21170066a28e8 Mon Sep 17 00:00:00 2001 From: Jonas Kulhanek Date: Wed, 31 Jan 2024 15:53:07 +0100 Subject: [PATCH 02/10] GUI api using partial update messages --- src/viser/_gui_api.py | 3 +- src/viser/_gui_handles.py | 47 ++++---- src/viser/_messages.py | 103 +++++++++++++----- .../client/src/ControlPanel/Generated.tsx | 6 +- .../client/src/ControlPanel/GuiState.tsx | 45 ++------ src/viser/client/src/WebsocketInterface.tsx | 22 +--- src/viser/client/src/WebsocketMessages.tsx | 66 ++++++----- src/viser/client/src/components/Button.tsx | 3 +- .../client/src/components/ButtonGroup.tsx | 2 + src/viser/infra/_messages.py | 36 +++--- src/viser/infra/_typescript_interface_gen.py | 16 +++ 11 files changed, 188 insertions(+), 161 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 9a8f7f4fc..f605ad0be 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -960,6 +960,7 @@ def _create_gui_input( # Construct handle. handle_state = _GuiHandleState( label=message.label, + message_type=type(message), typ=type(value), gui_api=self, value=value, @@ -981,7 +982,7 @@ def _create_gui_input( if not is_button: def sync_other_clients(client_id: ClientId, value: Any) -> None: - message = _messages.GuiSetValueMessage(id=handle_state.id, value=value) + message = _messages.GuiUpdateMessage(handle_state.id, handle_state.message_type, value=value) message.excluded_self_client = client_id self._get_api()._queue(message) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index aae704412..ff6219977 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -20,6 +20,7 @@ Type, TypeVar, Union, + ClassVar, ) import imageio.v3 as iio @@ -30,14 +31,12 @@ from ._icons_enum import IconName from ._message_api import _encode_image_base64 from ._messages import ( - GuiAddDropdownMessage, - GuiAddMarkdownMessage, GuiAddTabGroupMessage, GuiCloseModalMessage, GuiRemoveMessage, - GuiSetDisabledMessage, - GuiSetValueMessage, - GuiSetVisibleMessage, + GuiUpdateMessage, + GuiAddMarkdownMessage, + Message, ) from .infra import ClientId @@ -96,6 +95,8 @@ class _GuiHandleState(Generic[T]): initial_value: T hint: Optional[str] + message_type: Type[Message] + @dataclasses.dataclass class _GuiInputHandle(Generic[T]): @@ -138,7 +139,7 @@ def value(self, value: Union[T, onp.ndarray]) -> None: # Send to client, except for buttons. if not self._impl.is_button: self._impl.gui_api._get_api()._queue( - GuiSetValueMessage(self._impl.id, value) # type: ignore + GuiUpdateMessage(self._impl.id, self._impl.message_type, value=value) # type: ignore ) # Set internal state. We automatically convert numpy arrays to the expected @@ -177,7 +178,7 @@ def disabled(self, disabled: bool) -> None: return self._impl.gui_api._get_api()._queue( - GuiSetDisabledMessage(self._impl.id, disabled=disabled) + GuiUpdateMessage(self._impl.id, self._impl.message_type, disabled=disabled) ) self._impl.disabled = disabled @@ -193,7 +194,7 @@ def visible(self, visible: bool) -> None: return self._impl.gui_api._get_api()._queue( - GuiSetVisibleMessage(self._impl.id, visible=visible) + GuiUpdateMessage(self._impl.id, self._impl.message_type, visible=visible) ) self._impl.visible = visible @@ -313,16 +314,10 @@ def options(self, options: Iterable[StringType]) -> None: self._impl.initial_value = self._impl_options[0] self._impl.gui_api._get_api()._queue( - GuiAddDropdownMessage( - order=self._impl.order, - id=self._impl.id, - label=self._impl.label, - container_id=self._impl.container_id, - hint=self._impl.hint, - value=self._impl.initial_value, + GuiUpdateMessage( + self._impl.id, + self._impl.message_type, options=self._impl_options, - visible=self._impl.visible, - disabled=self._impl.disabled, ) ) @@ -370,14 +365,12 @@ def remove(self) -> None: def _sync_with_client(self) -> None: """Send a message that syncs tab state with the client.""" self._gui_api._get_api()._queue( - GuiAddTabGroupMessage( - order=self.order, - id=self._tab_group_id, - container_id=self._container_id, + GuiUpdateMessage( + self._tab_group_id, + GuiAddTabGroupMessage, tab_labels=tuple(self._labels), tab_icons_base64=tuple(self._icons_base64), tab_container_ids=tuple(tab._id for tab in self._tabs), - visible=self._visible, ) ) @@ -566,12 +559,10 @@ def content(self) -> str: def content(self, content: str) -> None: self._content = content self._gui_api._get_api()._queue( - GuiAddMarkdownMessage( - order=self._order, - id=self._id, + GuiUpdateMessage( + self._id, + GuiAddMarkdownMessage, markdown=_parse_markdown(content, self._image_root), - container_id=self._container_id, - visible=self._visible, ) ) @@ -591,7 +582,7 @@ def visible(self, visible: bool) -> None: if visible == self.visible: return - self._gui_api._get_api()._queue(GuiSetVisibleMessage(self._id, visible=visible)) + self._gui_api._get_api()._queue(GuiUpdateMessage(self._id, GuiAddMarkdownMessage, visible=visible)) self._visible = visible def __post_init__(self) -> None: diff --git a/src/viser/_messages.py b/src/viser/_messages.py index fabdea20a..6112b2bc1 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -4,16 +4,20 @@ from __future__ import annotations import dataclasses -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union, TypeVar, Callable, Type, Dict, cast import numpy as onp import numpy.typing as onpt from typing_extensions import Literal, override +import msgpack from . import infra, theme class Message(infra.Message): + _tags: Tuple[str, ...] = tuple() + + @override def redundancy_key(self) -> str: """Returns a unique key for this message, used for detecting redundant @@ -37,6 +41,21 @@ def redundancy_key(self) -> str: return "_".join(parts) +T = TypeVar("T", bound=Message) + + +def tag_class(tag: str) -> Callable[[T], T]: + """Decorator for tagging a class with a `type` field.""" + + def wrapper(cls: T) -> T: + cls._tags = (cls._tags or ()) + (tag,) + return cls + + return wrapper + + + + @dataclasses.dataclass class ViewerCameraMessage(Message): """Message for a posed viewer camera. @@ -346,6 +365,7 @@ class ResetSceneMessage(Message): """Reset scene.""" +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddFolderMessage(Message): order: float @@ -356,6 +376,7 @@ class GuiAddFolderMessage(Message): visible: bool +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddMarkdownMessage(Message): order: float @@ -365,6 +386,7 @@ class GuiAddMarkdownMessage(Message): visible: bool +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddTabGroupMessage(Message): order: float @@ -402,6 +424,7 @@ class GuiCloseModalMessage(Message): id: str +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddButtonMessage(_GuiAddInputBase): # All GUI elements currently need an `value` field. @@ -428,6 +451,7 @@ class GuiAddButtonMessage(_GuiAddInputBase): icon_base64: Optional[str] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddSliderMessage(_GuiAddInputBase): min: float @@ -437,6 +461,7 @@ class GuiAddSliderMessage(_GuiAddInputBase): precision: int +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddNumberMessage(_GuiAddInputBase): value: float @@ -446,21 +471,25 @@ class GuiAddNumberMessage(_GuiAddInputBase): max: Optional[float] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddRgbMessage(_GuiAddInputBase): value: Tuple[int, int, int] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddRgbaMessage(_GuiAddInputBase): value: Tuple[int, int, int, int] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddCheckboxMessage(_GuiAddInputBase): value: bool +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddVector2Message(_GuiAddInputBase): value: Tuple[float, float] @@ -470,6 +499,7 @@ class GuiAddVector2Message(_GuiAddInputBase): precision: int +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddVector3Message(_GuiAddInputBase): value: Tuple[float, float, float] @@ -479,17 +509,20 @@ class GuiAddVector3Message(_GuiAddInputBase): precision: int +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddTextMessage(_GuiAddInputBase): value: str +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddDropdownMessage(_GuiAddInputBase): value: str options: Tuple[str, ...] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddButtonGroupMessage(_GuiAddInputBase): value: str @@ -503,37 +536,49 @@ class GuiRemoveMessage(Message): id: str -@dataclasses.dataclass class GuiUpdateMessage(Message): - """Sent client->server when a GUI input is changed.""" - - id: str - value: Any + """Sent client<->server when a GUI component is changed.""" + + def __init__(self, id: str, type: Type[Message], **changes): + self.id = id + self._type = type + for k, v in changes.items(): + setattr(self, k, v) + + def as_serializable_dict(self) -> Dict[str, Any]: + """Convert a Python Message object into bytes.""" + from viser.infra._messages import get_type_hints_cached, _prepare_for_serialization + hints = get_type_hints_cached(self._type) + mapping = { + k: _prepare_for_serialization(v, hints[k]) for k, v in vars(self).items() if k != "_type" + } + mapping["component_type"] = self._type.__name__ + mapping["type"] = type(self).__name__ + return mapping + + @classmethod + def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]: + mapping["type"] = mapping.pop("component_type") + message_type = Message._subclass_from_type_string()[cast(str, mapping.pop("type"))] + kwargs = message_type._from_serializable_dict(mapping) + kwargs["type"] = message_type + return kwargs + + @classmethod + def _get_ts_type(cls): + return f''' +type _GuiComponentPropsPartial = T extends T ? Partial>: never; +export type GuiComponentPropsPartial = _GuiComponentPropsPartial; +export type _GuiComponentNames = T extends GuiAddComponentMessage ? T["type"] : never; +export type GuiComponentNames = _GuiComponentNames; +export type {cls.__name__} = {{ + id: string; + type: "{cls.__name__}"; + component_type: GuiComponentNames; +}} & GuiComponentPropsPartial; +''' -@dataclasses.dataclass -class GuiSetVisibleMessage(Message): - """Sent client->server when a GUI input is changed.""" - - id: str - visible: bool - - -@dataclasses.dataclass -class GuiSetDisabledMessage(Message): - """Sent client->server when a GUI input is changed.""" - - id: str - disabled: bool - - -@dataclasses.dataclass -class GuiSetValueMessage(Message): - """Sent server->client to set the value of a particular input.""" - - id: str - value: Any - @dataclasses.dataclass class ThemeConfigurationMessage(Message): diff --git a/src/viser/client/src/ControlPanel/Generated.tsx b/src/viser/client/src/ControlPanel/Generated.tsx index 6e12c131f..b430dfe38 100644 --- a/src/viser/client/src/ControlPanel/Generated.tsx +++ b/src/viser/client/src/ControlPanel/Generated.tsx @@ -60,11 +60,11 @@ export default function GeneratedGuiContainer({ containerId }: { containerId: st const viewer = React.useContext(ViewerContext)!; const messageSender = makeThrottledMessageSender(viewer.websocketRef, 50); function setValue(id: string, value: any) { - setGuiValue(id, value); - messageSender({ type: "GuiUpdateMessage", id: id, value: value }); + const { type } = updateGuiProps(id, { value }); + messageSender({ type: "GuiUpdateMessage", component_type: type, id, value }); } - const setGuiValue = viewer.useGui((state) => state.setGuiValue); + const updateGuiProps = viewer.useGui((state) => state.updateGuiProps); return void; addModal: (config: Messages.GuiModalMessage) => void; removeModal: (id: string) => void; - setGuiValue: (id: string, value: any) => void; - setGuiVisible: (id: string, visible: boolean) => void; - setGuiDisabled: (id: string, visible: boolean) => void; + updateGuiProps: (id: string, changes: Messages.GuiComponentPropsPartial) => Messages.GuiAddComponentMessage; removeGui: (id: string) => void; resetGui: () => void; } @@ -92,7 +76,7 @@ export function computeRelativeLuminance(color: string) { export function useGuiState(initialServer: string) { return React.useState(() => create( - immer((set) => ({ + immer((set, get) => ({ ...cleanGuiState, server: initialServer, setTheme: (theme) => @@ -122,21 +106,6 @@ export function useGuiState(initialServer: string) { set((state) => { state.modals = state.modals.filter((m) => m.id !== id); }), - setGuiValue: (id, value) => - set((state) => { - const config = state.guiConfigFromId[id] as any; - state.guiConfigFromId[id] = {...config, value} as GuiConfig; - }), - setGuiVisible: (id, visible) => - set((state) => { - const config = state.guiConfigFromId[id] as any; - state.guiConfigFromId[id] = {...config, visible} as GuiConfig; - }), - setGuiDisabled: (id, disabled) => - set((state) => { - const config = state.guiConfigFromId[id] as any; - state.guiConfigFromId[id] = {...config, disabled} as GuiConfig; - }), removeGui: (id) => set((state) => { const guiConfig = state.guiConfigFromId[id]; @@ -152,6 +121,14 @@ export function useGuiState(initialServer: string) { state.guiOrderFromId = {}; state.guiConfigFromId = {}; }), + updateGuiProps: (id, changes) => { + set((state) => { + const config = state.guiConfigFromId[id]; + if (config === undefined) return; + state.guiConfigFromId[id] = {...config, ...changes} as GuiConfig; + }); + return get().guiConfigFromId[id]; + } })), ), )[0]; diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 0d60776c8..045a201cc 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -67,9 +67,7 @@ function useMessageHandler() { const addModal = viewer.useGui((state) => state.addModal); const removeModal = viewer.useGui((state) => state.removeModal); const removeGui = viewer.useGui((state) => state.removeGui); - const setGuiValue = viewer.useGui((state) => state.setGuiValue); - const setGuiVisible = viewer.useGui((state) => state.setGuiVisible); - const setGuiDisabled = viewer.useGui((state) => state.setGuiDisabled); + const updateGuiProps = viewer.useGui((state) => state.updateGuiProps); const setClickable = viewer.useSceneTree((state) => state.setClickable); // Same as addSceneNode, but make a parent in the form of a dummy coordinate @@ -746,19 +744,11 @@ function useMessageHandler() { viewer.backgroundMaterialRef.current!.uniforms.enabled.value = false; return; } - // Set the value of a GUI input. - case "GuiSetValueMessage": { - setGuiValue(message.id, message.value); - return; - } - // Set the hidden state of a GUI input. - case "GuiSetVisibleMessage": { - setGuiVisible(message.id, message.visible); - return; - } - // Set the disabled state of a GUI input. - case "GuiSetDisabledMessage": { - setGuiDisabled(message.id, message.disabled); + // Update props of a GUI component + case "GuiUpdateMessage": { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { id, type, component_type, ...changes } = message; + updateGuiProps(message.id, changes); return; } // Remove a GUI input. diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index 3f5409416..18b1904eb 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -592,42 +592,26 @@ export interface GuiRemoveMessage { type: "GuiRemoveMessage"; id: string; } -/** Sent client->server when a GUI input is changed. +/** Sent client<->server when a GUI component is changed. * * (automatically generated) */ -export interface GuiUpdateMessage { - type: "GuiUpdateMessage"; - id: string; - value: any; -} -/** Sent client->server when a GUI input is changed. - * - * (automatically generated) - */ -export interface GuiSetVisibleMessage { - type: "GuiSetVisibleMessage"; - id: string; - visible: boolean; -} -/** Sent client->server when a GUI input is changed. - * - * (automatically generated) - */ -export interface GuiSetDisabledMessage { - type: "GuiSetDisabledMessage"; - id: string; - disabled: boolean; -} -/** Sent server->client to set the value of a particular input. - * - * (automatically generated) - */ -export interface GuiSetValueMessage { - type: "GuiSetValueMessage"; + +type _GuiComponentPropsPartial = T extends T + ? Partial> + : never; +export type GuiComponentPropsPartial = + _GuiComponentPropsPartial; +export type _GuiComponentNames = T extends GuiAddComponentMessage + ? T["type"] + : never; +export type GuiComponentNames = _GuiComponentNames; +export type GuiUpdateMessage = { id: string; - value: any; -} + type: "GuiUpdateMessage"; + component_type: GuiComponentNames; +} & GuiComponentPropsPartial; + /** Message from server->client to configure parts of the GUI. * * (automatically generated) @@ -816,9 +800,6 @@ export type Message = | GuiCloseModalMessage | GuiRemoveMessage | GuiUpdateMessage - | GuiSetVisibleMessage - | GuiSetDisabledMessage - | GuiSetValueMessage | ThemeConfigurationMessage | CatmullRomSplineMessage | CubicBezierSplineMessage @@ -830,3 +811,18 @@ export type Message = | ShareUrlUpdated | ShareUrlDisconnect | SetGuiPanelLabelMessage; +export type GuiAddComponentMessage = + | GuiAddFolderMessage + | GuiAddMarkdownMessage + | GuiAddTabGroupMessage + | GuiAddButtonMessage + | GuiAddSliderMessage + | GuiAddNumberMessage + | GuiAddRgbMessage + | GuiAddRgbaMessage + | GuiAddCheckboxMessage + | GuiAddVector2Message + | GuiAddVector3Message + | GuiAddTextMessage + | GuiAddDropdownMessage + | GuiAddButtonGroupMessage; diff --git a/src/viser/client/src/components/Button.tsx b/src/viser/client/src/components/Button.tsx index 851582a0b..fe5cb607c 100644 --- a/src/viser/client/src/components/Button.tsx +++ b/src/viser/client/src/components/Button.tsx @@ -13,7 +13,7 @@ import { Button } from "@mantine/core"; import React from "react"; -export default function ButtonComponent({ id, visible, disabled, label, ...otherProps }: GuiAddButtonMessage) { +export default function ButtonComponent({ id, visible, disabled, label, type, ...otherProps }: GuiAddButtonMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; const theme = useMantineTheme(); const { color, icon_base64 } = otherProps; @@ -32,6 +32,7 @@ export default function ButtonComponent({ id, visible, disabled, label, ...other onClick={() => messageSender({ type: "GuiUpdateMessage", + component_type: type, id: id, value: true, }) diff --git a/src/viser/client/src/components/ButtonGroup.tsx b/src/viser/client/src/components/ButtonGroup.tsx index ce747ef2f..fd0061956 100644 --- a/src/viser/client/src/components/ButtonGroup.tsx +++ b/src/viser/client/src/components/ButtonGroup.tsx @@ -7,6 +7,7 @@ import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; export default function ButtonGroupComponent({ id, hint, + type, label, visible, disabled, @@ -22,6 +23,7 @@ export default function ButtonGroupComponent({ onClick={() => messageSender({ type: "GuiUpdateMessage", + component_type: type, id: id, value: option, }) diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index ce9d2af20..08693b98b 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -77,24 +77,19 @@ class Message(abc.ABC): def as_serializable_dict(self) -> Dict[str, Any]: """Convert a Python Message object into bytes.""" - hints = get_type_hints_cached(type(self)) + message_type = type(self) + hints = get_type_hints_cached(message_type) out = { k: _prepare_for_serialization(v, hints[k]) for k, v in vars(self).items() } - out["type"] = type(self).__name__ + out["type"] = message_type.__name__ return out @classmethod - def deserialize(cls, message: bytes) -> Message: - """Convert bytes into a Python Message object.""" - mapping = msgpack.unpackb(message) + def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]: + """Convert a dict message back into a Python Message object.""" - # msgpack deserializes to lists by default, but all of our annotations use - # tuples. - mapping = { - k: tuple(v) if isinstance(v, list) else v for k, v in mapping.items() - } - message_type = cls._subclass_from_type_string()[cast(str, mapping.pop("type"))] + hints = get_type_hints_cached(cls) # If annotated as a float but we got an integer, cast to float. These # are both `number` in Javascript. @@ -109,9 +104,22 @@ def coerce_floats(value: Any, annotation: Type[Any]) -> Any: else: return value - type_hints = get_type_hints(message_type) - mapping = {k: coerce_floats(v, type_hints[k]) for k, v in mapping.items()} - return message_type(**mapping) # type: ignore + mapping = {k: coerce_floats(v, hints[k]) for k, v in mapping.items()} + return mapping + + @classmethod + def deserialize(cls, message: bytes) -> Message: + """Convert bytes into a Python Message object.""" + mapping = msgpack.unpackb(message) + + # msgpack deserializes to lists by default, but all of our annotations use + # tuples. + mapping = { + k: tuple(v) if isinstance(v, list) else v for k, v in mapping.items() + } + message_type = cls._subclass_from_type_string()[cast(str, mapping.pop("type"))] + message_kwargs = message_type._from_serializable_dict(mapping) + return message_type(**message_kwargs) @classmethod @functools.lru_cache(maxsize=100) diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index 84ba23c23..2463b15a0 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -1,4 +1,5 @@ import dataclasses +from collections import defaultdict from typing import Any, ClassVar, Type, Union, cast, get_type_hints import numpy as onp @@ -74,6 +75,7 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: """Generate TypeScript definitions for all subclasses of a base message class.""" out_lines = [] message_types = message_cls.get_subclasses() + tag_map = defaultdict(list) # Generate interfaces for each specific message. for cls in message_types: @@ -86,6 +88,13 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: out_lines.append(" * (automatically generated)") out_lines.append(" */") + for tag in cls._tags or []: + tag_map[tag].append(cls.__name__) + + if hasattr(cls, "_get_ts_type"): + out_lines.append(cls._get_ts_type()) + continue + out_lines.append(f"export interface {cls.__name__} " + "{") out_lines.append(f' type: "{cls.__name__}";') field_names = set([f.name for f in dataclasses.fields(cls)]) # type: ignore @@ -106,6 +115,13 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: out_lines.append(f" | {cls.__name__}") out_lines[-1] = out_lines[-1] + ";" + # Generate union type over all tags. + for tag, cls_names in tag_map.items(): + out_lines.append(f"export type {tag} = ") + for cls in cls_names: + out_lines.append(f" | {cls}") + out_lines[-1] = out_lines[-1] + ";" + interfaces = "\n".join(out_lines) + "\n" # Add header and return. From cb3412eb51585846660b9892613917327f5a1de7 Mon Sep 17 00:00:00 2001 From: Jonas Kulhanek Date: Wed, 31 Jan 2024 16:31:41 +0100 Subject: [PATCH 03/10] Update message typecasting --- src/viser/_gui_api.py | 41 +++++++---- src/viser/_gui_handles.py | 4 +- src/viser/_messages.py | 9 ++- src/viser/infra/_messages.py | 77 +++++++++++++------- src/viser/infra/_typescript_interface_gen.py | 8 +- 5 files changed, 90 insertions(+), 49 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index f605ad0be..6b51df724 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -131,24 +131,35 @@ def _handle_gui_updates( handle_state = handle._impl - # Do some type casting. This is necessary when we expect floats but the - # Javascript side gives us integers. - if handle_state.typ is tuple: - assert len(message.value) == len(handle_state.value) - value = tuple( - type(handle_state.value[i])(message.value[i]) - for i in range(len(message.value)) - ) - else: - value = handle_state.typ(message.value) + has_changed = False + changes = message.changes() + for k, v in changes.items(): + current_value = getattr(handle_state, k, None) + if current_value != v: + has_changed = True + + if "value" in changes: + # Do some type casting. This is necessary when we expect floats but the + # Javascript side gives us integers. + value = changes["value"] + if handle_state.typ is tuple: + assert len(value) == len(handle_state.value) + value = tuple( + type(handle_state.value[i])(message.value[i]) + for i in range(len(value)) + ) + else: + value = handle_state.typ(value) + changes["value"] = value # Only call update when value has actually changed. - if not handle_state.is_button and value == handle_state.value: + if not handle_state.is_button and not has_changed: return # Update state. with self._get_api()._atomic_lock: - handle_state.value = value + for k, v in changes.items(): + setattr(handle_state, k, v) handle_state.update_timestamp = time.time() # Trigger callbacks. @@ -166,7 +177,7 @@ def _handle_gui_updates( cb(GuiEvent(client, client_id, handle)) if handle_state.sync_cb is not None: - handle_state.sync_cb(client_id, value) + handle_state.sync_cb(client_id, changes) def _get_container_id(self) -> str: """Get container ID associated with the current thread.""" @@ -981,8 +992,8 @@ def _create_gui_input( # This will be a no-op for client handles. if not is_button: - def sync_other_clients(client_id: ClientId, value: Any) -> None: - message = _messages.GuiUpdateMessage(handle_state.id, handle_state.message_type, value=value) + def sync_other_clients(client_id: ClientId, changes) -> None: + message = _messages.GuiUpdateMessage(handle_state.id, handle_state.message_type, **changes) message.excluded_self_client = client_id self._get_api()._queue(message) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index ff6219977..5435df7b9 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -20,7 +20,7 @@ Type, TypeVar, Union, - ClassVar, + Any, ) import imageio.v3 as iio @@ -84,7 +84,7 @@ class _GuiHandleState(Generic[T]): is_button: bool """Indicates a button element, which requires special handling.""" - sync_cb: Optional[Callable[[ClientId, T], None]] + sync_cb: Optional[Callable[[ClientId, Dict[str, Any]], None]] """Callback for synchronizing inputs across clients.""" disabled: bool diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 6112b2bc1..e04052291 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -4,7 +4,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Optional, Tuple, Union, TypeVar, Callable, Type, Dict, cast +from typing import Any, Optional, Tuple, Union, TypeVar, Callable, Type, Dict, cast, ClassVar import numpy as onp import numpy.typing as onpt @@ -15,7 +15,7 @@ class Message(infra.Message): - _tags: Tuple[str, ...] = tuple() + _tags: ClassVar[Tuple[str, ...]] = tuple() @override @@ -41,7 +41,7 @@ def redundancy_key(self) -> str: return "_".join(parts) -T = TypeVar("T", bound=Message) +T = TypeVar("T", bound=Type[Message]) def tag_class(tag: str) -> Callable[[T], T]: @@ -545,6 +545,9 @@ def __init__(self, id: str, type: Type[Message], **changes): for k, v in changes.items(): setattr(self, k, v) + def changes(self): + return {k: v for k, v in vars(self).items() if k not in {"id", "_type"}} + def as_serializable_dict(self) -> Dict[str, Any]: """Convert a Python Message object into bytes.""" from viser.infra._messages import get_type_hints_cached, _prepare_for_serialization diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index 08693b98b..2616ce88a 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -18,6 +18,38 @@ ClientId = Any +def _prepare_for_deserialization(value: Any, annotation: Type) -> Any: + # If annotated as a float but we got an integer, cast to float. These + # are both `number` in Javascript. + if annotation is float: + return float(value) + elif annotation is int: + return int(value) + elif get_origin(annotation) is tuple: + out = [] + args = get_args(annotation) + if ... in args: + if len(value) < len(args) - 1: + warnings.warn(f"[viser] {value} does not match annotation {annotation}") + return value + ellipsis_index = args.index(...) + num_ellipsis = len(value) - len(args) + 2 + args = args[:(ellipsis_index - 1)] + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) + args[ellipsis_index + 1 :] + + if len(value) != len(args): + warnings.warn(f"[viser] {value} does not match annotation {annotation}") + return value + + for i, v in enumerate(value): + out.append( + # Hack to be OK with wrong type annotations. + # https://github.com/nerfstudio-project/nerfstudio/pull/1805 + _prepare_for_deserialization(v, args[i]) if i < len(args) else v + ) + return tuple(out) + return value + + def _prepare_for_serialization(value: Any, annotation: Type) -> Any: """Prepare any special types for serialization.""" @@ -38,19 +70,25 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any: out = [] args = get_args(annotation) - if len(args) >= 1: - if len(args) >= 2 and args[1] == ...: - args = (args[0],) * len(value) - elif len(value) != len(args): + if ... in args: + if len(value) < len(args) - 1: warnings.warn(f"[viser] {value} does not match annotation {annotation}") - - for i, v in enumerate(value): - out.append( - # Hack to be OK with wrong type annotations. - # https://github.com/nerfstudio-project/nerfstudio/pull/1805 - _prepare_for_serialization(v, args[i]) if i < len(args) else v - ) - return tuple(out) + return value + ellipsis_index = args.index(...) + num_ellipsis = len(value) - len(args) + 2 + args = args[:(ellipsis_index - 1)] + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) + args[ellipsis_index + 1 :] + + if len(value) != len(args): + warnings.warn(f"[viser] {value} does not match annotation {annotation}") + return value + + for i, v in enumerate(value): + out.append( + # Hack to be OK with wrong type annotations. + # https://github.com/nerfstudio-project/nerfstudio/pull/1805 + _prepare_for_serialization(v, args[i]) if i < len(args) else v + ) + return tuple(out) # For arrays, we serialize underlying data directly. The client is responsible for # reading using the correct dtype. @@ -91,20 +129,7 @@ def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]: hints = get_type_hints_cached(cls) - # If annotated as a float but we got an integer, cast to float. These - # are both `number` in Javascript. - def coerce_floats(value: Any, annotation: Type[Any]) -> Any: - if annotation is float: - return float(value) - elif get_origin(annotation) is tuple: - return tuple( - coerce_floats(value[i], typ) - for i, typ in enumerate(get_args(annotation)) - ) - else: - return value - - mapping = {k: coerce_floats(v, hints[k]) for k, v in mapping.items()} + mapping = {k: _prepare_for_deserialization(v, hints[k]) for k, v in mapping.items()} return mapping @classmethod diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index 2463b15a0..f82e7fa46 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -88,11 +88,13 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: out_lines.append(" * (automatically generated)") out_lines.append(" */") - for tag in cls._tags or []: + for tag in getattr(cls, "_tags", []): tag_map[tag].append(cls.__name__) - if hasattr(cls, "_get_ts_type"): - out_lines.append(cls._get_ts_type()) + get_ts_type = getattr(cls, "_get_ts_type", None) + if get_ts_type is not None: + assert callable(get_ts_type) + out_lines.append(get_ts_type()) continue out_lines.append(f"export interface {cls.__name__} " + "{") From 2803aa89581f6f67d9df69f224b669464389d46a Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Mon, 5 Feb 2024 23:30:28 -0800 Subject: [PATCH 04/10] Embrace untyped GuiUpdateMessage? --- src/viser/_gui_api.py | 43 ++++++------ src/viser/_gui_handles.py | 38 +++++----- src/viser/_messages.py | 70 ++++++------------- .../client/src/ControlPanel/Generated.tsx | 57 +++++++++------ .../src/ControlPanel/GuiComponentContext.tsx | 8 +-- .../client/src/ControlPanel/GuiState.tsx | 19 +++-- src/viser/client/src/WebsocketInterface.tsx | 8 +-- src/viser/client/src/WebsocketMessages.tsx | 22 ++---- src/viser/client/src/components/Button.tsx | 25 ++++--- .../client/src/components/ButtonGroup.tsx | 53 +++++++------- src/viser/client/src/components/Checkbox.tsx | 67 ++++++++++-------- src/viser/client/src/components/Dropdown.tsx | 65 +++++++++-------- src/viser/client/src/components/Folder.tsx | 17 +++-- src/viser/client/src/components/Markdown.tsx | 8 ++- .../client/src/components/NumberInput.tsx | 63 ++++++++++------- src/viser/client/src/components/Rgb.tsx | 47 ++++++++----- src/viser/client/src/components/Rgba.tsx | 45 +++++++----- src/viser/client/src/components/Slider.tsx | 34 +++++---- src/viser/client/src/components/TabGroup.tsx | 7 +- src/viser/client/src/components/TextInput.tsx | 38 +++++----- src/viser/client/src/components/Vector2.tsx | 40 +++++++---- src/viser/client/src/components/Vector3.tsx | 40 +++++++---- src/viser/client/src/components/common.tsx | 24 ++++--- src/viser/client/src/components/utils.tsx | 9 ++- 24 files changed, 456 insertions(+), 391 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 6b51df724..235c882f2 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -129,38 +129,34 @@ def _handle_gui_updates( if handle is None: return + prop_name = message.prop_name + prop_value = message.prop_value + del message + handle_state = handle._impl + assert hasattr(handle_state, prop_name) + current_value = getattr(handle_state, prop_name) - has_changed = False - changes = message.changes() - for k, v in changes.items(): - current_value = getattr(handle_state, k, None) - if current_value != v: - has_changed = True + has_changed = current_value != prop_value - if "value" in changes: + if prop_name == "value": # Do some type casting. This is necessary when we expect floats but the # Javascript side gives us integers. - value = changes["value"] if handle_state.typ is tuple: - assert len(value) == len(handle_state.value) - value = tuple( - type(handle_state.value[i])(message.value[i]) - for i in range(len(value)) + assert len(prop_value) == len(handle_state.value) + prop_value = tuple( + type(handle_state.value[i])(prop_value[i]) + for i in range(len(prop_value)) ) else: - value = handle_state.typ(value) - changes["value"] = value + prop_value = handle_state.typ(prop_value) # Only call update when value has actually changed. if not handle_state.is_button and not has_changed: return # Update state. - with self._get_api()._atomic_lock: - for k, v in changes.items(): - setattr(handle_state, k, v) - handle_state.update_timestamp = time.time() + setattr(handle_state, prop_name, prop_value) # Trigger callbacks. for cb in handle_state.update_cb: @@ -176,8 +172,9 @@ def _handle_gui_updates( assert False cb(GuiEvent(client, client_id, handle)) + if handle_state.sync_cb is not None: - handle_state.sync_cb(client_id, changes) + handle_state.sync_cb(client_id, prop_name, prop_value) def _get_container_id(self) -> str: """Get container ID associated with the current thread.""" @@ -992,8 +989,12 @@ def _create_gui_input( # This will be a no-op for client handles. if not is_button: - def sync_other_clients(client_id: ClientId, changes) -> None: - message = _messages.GuiUpdateMessage(handle_state.id, handle_state.message_type, **changes) + def sync_other_clients( + client_id: ClientId, prop_name: str, prop_value: Any + ) -> None: + message = _messages.GuiUpdateMessage( + handle_state.id, prop_name, prop_value + ) message.excluded_self_client = client_id self._get_api()._queue(message) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 5435df7b9..2857b622f 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Any, Callable, Dict, Generic, @@ -20,7 +21,6 @@ Type, TypeVar, Union, - Any, ) import imageio.v3 as iio @@ -31,11 +31,11 @@ from ._icons_enum import IconName from ._message_api import _encode_image_base64 from ._messages import ( + GuiAddMarkdownMessage, GuiAddTabGroupMessage, GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, - GuiAddMarkdownMessage, Message, ) from .infra import ClientId @@ -84,7 +84,7 @@ class _GuiHandleState(Generic[T]): is_button: bool """Indicates a button element, which requires special handling.""" - sync_cb: Optional[Callable[[ClientId, Dict[str, Any]], None]] + sync_cb: Optional[Callable[[ClientId, str, Any], None]] """Callback for synchronizing inputs across clients.""" disabled: bool @@ -139,7 +139,7 @@ def value(self, value: Union[T, onp.ndarray]) -> None: # Send to client, except for buttons. if not self._impl.is_button: self._impl.gui_api._get_api()._queue( - GuiUpdateMessage(self._impl.id, self._impl.message_type, value=value) # type: ignore + GuiUpdateMessage(self._impl.id, "value", value) ) # Set internal state. We automatically convert numpy arrays to the expected @@ -178,7 +178,7 @@ def disabled(self, disabled: bool) -> None: return self._impl.gui_api._get_api()._queue( - GuiUpdateMessage(self._impl.id, self._impl.message_type, disabled=disabled) + GuiUpdateMessage(self._impl.id, "disabled", disabled) ) self._impl.disabled = disabled @@ -194,7 +194,7 @@ def visible(self, visible: bool) -> None: return self._impl.gui_api._get_api()._queue( - GuiUpdateMessage(self._impl.id, self._impl.message_type, visible=visible) + GuiUpdateMessage(self._impl.id, "visible", visible) ) self._impl.visible = visible @@ -314,11 +314,7 @@ def options(self, options: Iterable[StringType]) -> None: self._impl.initial_value = self._impl_options[0] self._impl.gui_api._get_api()._queue( - GuiUpdateMessage( - self._impl.id, - self._impl.message_type, - options=self._impl_options, - ) + GuiUpdateMessage(self._impl.id, "options", self._impl_options) ) if self.value not in self._impl_options: @@ -364,13 +360,19 @@ def remove(self) -> None: def _sync_with_client(self) -> None: """Send a message that syncs tab state with the client.""" + self._gui_api._get_api()._queue( + GuiUpdateMessage(self._tab_group_id, "tab_labels", tuple(self._labels)) + ) + self._gui_api._get_api()._queue( + GuiUpdateMessage( + self._tab_group_id, "tab_icons_base64", tuple(self._icons_base64) + ) + ) self._gui_api._get_api()._queue( GuiUpdateMessage( self._tab_group_id, - GuiAddTabGroupMessage, - tab_labels=tuple(self._labels), - tab_icons_base64=tuple(self._icons_base64), - tab_container_ids=tuple(tab._id for tab in self._tabs), + "tab_container_ids", + tuple(tab._id for tab in self._tabs), ) ) @@ -561,8 +563,8 @@ def content(self, content: str) -> None: self._gui_api._get_api()._queue( GuiUpdateMessage( self._id, - GuiAddMarkdownMessage, - markdown=_parse_markdown(content, self._image_root), + "markdown", + _parse_markdown(content, self._image_root), ) ) @@ -582,7 +584,7 @@ def visible(self, visible: bool) -> None: if visible == self.visible: return - self._gui_api._get_api()._queue(GuiUpdateMessage(self._id, GuiAddMarkdownMessage, visible=visible)) + self._gui_api._get_api()._queue(GuiUpdateMessage(self._id, "visible", visible)) self._visible = visible def __post_init__(self) -> None: diff --git a/src/viser/_messages.py b/src/viser/_messages.py index e04052291..61fab99ed 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -4,12 +4,23 @@ from __future__ import annotations import dataclasses -from typing import Any, Optional, Tuple, Union, TypeVar, Callable, Type, Dict, cast, ClassVar +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) +import msgpack import numpy as onp import numpy.typing as onpt from typing_extensions import Literal, override -import msgpack from . import infra, theme @@ -17,7 +28,6 @@ class Message(infra.Message): _tags: ClassVar[Tuple[str, ...]] = tuple() - @override def redundancy_key(self) -> str: """Returns a unique key for this message, used for detecting redundant @@ -54,8 +64,6 @@ def wrapper(cls: T) -> T: return wrapper - - @dataclasses.dataclass class ViewerCameraMessage(Message): """Message for a posed viewer camera. @@ -536,51 +544,17 @@ class GuiRemoveMessage(Message): id: str +@dataclasses.dataclass class GuiUpdateMessage(Message): - """Sent client<->server when a GUI component is changed.""" - - def __init__(self, id: str, type: Type[Message], **changes): - self.id = id - self._type = type - for k, v in changes.items(): - setattr(self, k, v) - - def changes(self): - return {k: v for k, v in vars(self).items() if k not in {"id", "_type"}} - - def as_serializable_dict(self) -> Dict[str, Any]: - """Convert a Python Message object into bytes.""" - from viser.infra._messages import get_type_hints_cached, _prepare_for_serialization - hints = get_type_hints_cached(self._type) - mapping = { - k: _prepare_for_serialization(v, hints[k]) for k, v in vars(self).items() if k != "_type" - } - mapping["component_type"] = self._type.__name__ - mapping["type"] = type(self).__name__ - return mapping - - @classmethod - def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]: - mapping["type"] = mapping.pop("component_type") - message_type = Message._subclass_from_type_string()[cast(str, mapping.pop("type"))] - kwargs = message_type._from_serializable_dict(mapping) - kwargs["type"] = message_type - return kwargs - - @classmethod - def _get_ts_type(cls): - return f''' -type _GuiComponentPropsPartial = T extends T ? Partial>: never; -export type GuiComponentPropsPartial = _GuiComponentPropsPartial; -export type _GuiComponentNames = T extends GuiAddComponentMessage ? T["type"] : never; -export type GuiComponentNames = _GuiComponentNames; -export type {cls.__name__} = {{ - id: string; - type: "{cls.__name__}"; - component_type: GuiComponentNames; -}} & GuiComponentPropsPartial; -''' + """Sent client<->server when any property of a GUI component is changed.""" + + id: str + prop_name: str + prop_value: Any + @override + def redundancy_key(self) -> str: + return type(self).__name__ + "-" + self.id + "-" + self.prop_name @dataclasses.dataclass diff --git a/src/viser/client/src/ControlPanel/Generated.tsx b/src/viser/client/src/ControlPanel/Generated.tsx index b430dfe38..55ed0e149 100644 --- a/src/viser/client/src/ControlPanel/Generated.tsx +++ b/src/viser/client/src/ControlPanel/Generated.tsx @@ -3,9 +3,7 @@ import { makeThrottledMessageSender } from "../WebsocketFunctions"; import { GuiConfig } from "./GuiState"; import { GuiComponentContext } from "./GuiComponentContext"; -import { - Box, -} from "@mantine/core"; +import { Box } from "@mantine/core"; import React from "react"; import ButtonComponent from "../components/Button"; import SliderComponent from "../components/Slider"; @@ -22,7 +20,6 @@ import MarkdownComponent from "../components/Markdown"; import TabGroupComponent from "../components/TabGroup"; import FolderComponent from "../components/Folder"; - function GuiContainer({ containerId }: { containerId: string }) { const viewer = React.useContext(ViewerContext)!; @@ -40,7 +37,9 @@ function GuiContainer({ containerId }: { containerId: string }) { })); let pb = undefined; guiIdOrderPairArray = guiIdOrderPairArray.sort((a, b) => a.order - b.order); - const inputProps = viewer.useGui((state) => guiIdOrderPairArray.map(pair => state.guiConfigFromId[pair.id])); + const inputProps = viewer.useGui((state) => + guiIdOrderPairArray.map((pair) => state.guiConfigFromId[pair.id]), + ); const lastProps = inputProps && inputProps[inputProps.length - 1]; // Done to match the old behaviour. Is it still needed? @@ -49,31 +48,45 @@ function GuiContainer({ containerId }: { containerId: string }) { } const out = ( - {inputProps.map((conf) => )} + {inputProps.map((conf) => ( + + ))} ); return out; } /** Root of generated inputs. */ -export default function GeneratedGuiContainer({ containerId }: { containerId: string; }) { +export default function GeneratedGuiContainer({ + containerId, +}: { + containerId: string; +}) { const viewer = React.useContext(ViewerContext)!; + const updateGuiProps = viewer.useGui((state) => state.updateGuiProps); const messageSender = makeThrottledMessageSender(viewer.websocketRef, 50); + function setValue(id: string, value: any) { - const { type } = updateGuiProps(id, { value }); - messageSender({ type: "GuiUpdateMessage", component_type: type, id, value }); + updateGuiProps(id, "value", value); + messageSender({ + type: "GuiUpdateMessage", + id: id, + prop_name: "value", + prop_value: value, + }); } - - const updateGuiProps = viewer.useGui((state) => state.updateGuiProps); - return - - - + return ( + + + + ); } /** A single generated GUI element. */ @@ -92,7 +105,7 @@ function GeneratedInput(conf: GuiConfig) { case "GuiAddNumberMessage": return ; case "GuiAddTextMessage": - return ; + return ; case "GuiAddCheckboxMessage": return ; case "GuiAddVector2Message": @@ -114,4 +127,4 @@ function GeneratedInput(conf: GuiConfig) { function assertNeverType(x: never): never { throw new Error("Unexpected object: " + (x as any).type); -} \ No newline at end of file +} diff --git a/src/viser/client/src/ControlPanel/GuiComponentContext.tsx b/src/viser/client/src/ControlPanel/GuiComponentContext.tsx index f566efc1c..91948101c 100644 --- a/src/viser/client/src/ControlPanel/GuiComponentContext.tsx +++ b/src/viser/client/src/ControlPanel/GuiComponentContext.tsx @@ -2,10 +2,10 @@ import * as React from "react"; import * as Messages from "../WebsocketMessages"; interface GuiComponentContext { - folderDepth: number, - setValue: (id: string, value: any) => void, - messageSender: (message: Messages.Message) => void, - GuiContainer: React.FC<{ containerId: string }>, + folderDepth: number; + setValue: (id: string, value: any) => void; + messageSender: (message: Messages.Message) => void; + GuiContainer: React.FC<{ containerId: string }>; } export const GuiComponentContext = React.createContext({ diff --git a/src/viser/client/src/ControlPanel/GuiState.tsx b/src/viser/client/src/ControlPanel/GuiState.tsx index 2a513b57a..ea15f8c93 100644 --- a/src/viser/client/src/ControlPanel/GuiState.tsx +++ b/src/viser/client/src/ControlPanel/GuiState.tsx @@ -34,7 +34,7 @@ interface GuiActions { addGui: (config: GuiConfig) => void; addModal: (config: Messages.GuiModalMessage) => void; removeModal: (id: string) => void; - updateGuiProps: (id: string, changes: Messages.GuiComponentPropsPartial) => Messages.GuiAddComponentMessage; + updateGuiProps: (id: string, prop_name: string, prop_value: any) => void; removeGui: (id: string) => void; resetGui: () => void; } @@ -76,7 +76,7 @@ export function computeRelativeLuminance(color: string) { export function useGuiState(initialServer: string) { return React.useState(() => create( - immer((set, get) => ({ + immer((set) => ({ ...cleanGuiState, server: initialServer, setTheme: (theme) => @@ -121,14 +121,19 @@ export function useGuiState(initialServer: string) { state.guiOrderFromId = {}; state.guiConfigFromId = {}; }), - updateGuiProps: (id, changes) => { + updateGuiProps: (id, name, value) => { set((state) => { const config = state.guiConfigFromId[id]; - if (config === undefined) return; - state.guiConfigFromId[id] = {...config, ...changes} as GuiConfig; + if (config === undefined) { + console.error("Tried to update non-existent component", id); + return; + } + state.guiConfigFromId[id] = { + ...config, + [name]: value, + } as GuiConfig; }); - return get().guiConfigFromId[id]; - } + }, })), ), )[0]; diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 045a201cc..faece469a 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -665,7 +665,9 @@ function useMessageHandler() { }} > - + @@ -746,9 +748,7 @@ function useMessageHandler() { } // Update props of a GUI component case "GuiUpdateMessage": { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { id, type, component_type, ...changes } = message; - updateGuiProps(message.id, changes); + updateGuiProps(message.id, message.prop_name, message.prop_value); return; } // Remove a GUI input. diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index 18b1904eb..c014f43d9 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -592,26 +592,16 @@ export interface GuiRemoveMessage { type: "GuiRemoveMessage"; id: string; } -/** Sent client<->server when a GUI component is changed. +/** Sent client<->server when any property of a GUI component is changed. * * (automatically generated) */ - -type _GuiComponentPropsPartial = T extends T - ? Partial> - : never; -export type GuiComponentPropsPartial = - _GuiComponentPropsPartial; -export type _GuiComponentNames = T extends GuiAddComponentMessage - ? T["type"] - : never; -export type GuiComponentNames = _GuiComponentNames; -export type GuiUpdateMessage = { - id: string; +export interface GuiUpdateMessage { type: "GuiUpdateMessage"; - component_type: GuiComponentNames; -} & GuiComponentPropsPartial; - + id: string; + prop_name: string; + prop_value: any; +} /** Message from server->client to configure parts of the GUI. * * (automatically generated) diff --git a/src/viser/client/src/components/Button.tsx b/src/viser/client/src/components/Button.tsx index fe5cb607c..fb8803917 100644 --- a/src/viser/client/src/components/Button.tsx +++ b/src/viser/client/src/components/Button.tsx @@ -1,19 +1,18 @@ -import { - GuiAddButtonMessage, -} from "../WebsocketMessages"; +import { GuiAddButtonMessage } from "../WebsocketMessages"; import { computeRelativeLuminance } from "../ControlPanel/GuiState"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; -import { - Box, - Image, - useMantineTheme, -} from "@mantine/core"; +import { Box, Image, useMantineTheme } from "@mantine/core"; import { Button } from "@mantine/core"; import React from "react"; - -export default function ButtonComponent({ id, visible, disabled, label, type, ...otherProps }: GuiAddButtonMessage) { +export default function ButtonComponent({ + id, + visible, + disabled, + label, + ...otherProps +}: GuiAddButtonMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; const theme = useMantineTheme(); const { color, icon_base64 } = otherProps; @@ -32,9 +31,9 @@ export default function ButtonComponent({ id, visible, disabled, label, type, .. onClick={() => messageSender({ type: "GuiUpdateMessage", - component_type: type, id: id, - value: true, + prop_name: "value", + prop_value: true, }) } style={{ height: "2.125em" }} @@ -67,4 +66,4 @@ export default function ButtonComponent({ id, visible, disabled, label, type, .. ); -} \ No newline at end of file +} diff --git a/src/viser/client/src/components/ButtonGroup.tsx b/src/viser/client/src/components/ButtonGroup.tsx index fd0061956..1fc1795ab 100644 --- a/src/viser/client/src/components/ButtonGroup.tsx +++ b/src/viser/client/src/components/ButtonGroup.tsx @@ -7,7 +7,6 @@ import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; export default function ButtonGroupComponent({ id, hint, - type, label, visible, disabled, @@ -15,28 +14,30 @@ export default function ButtonGroupComponent({ }: GuiAddButtonGroupMessage) { const { messageSender } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; - return - - {options.map((option, index) => ( - - messageSender({ - type: "GuiUpdateMessage", - component_type: type, - id: id, - value: option, - }) - } - style={{ flexGrow: 1, width: 0 }} - disabled={disabled} - compact - size="xs" - variant="outline" - > - {option} - - ))} - - ; -} \ No newline at end of file + return ( + + + {options.map((option, index) => ( + + messageSender({ + type: "GuiUpdateMessage", + id: id, + prop_name: "value", + prop_value: option, + }) + } + style={{ flexGrow: 1, width: 0 }} + disabled={disabled} + compact + size="xs" + variant="outline" + > + {option} + + ))} + + + ); +} diff --git a/src/viser/client/src/components/Checkbox.tsx b/src/viser/client/src/components/Checkbox.tsx index 5229bc4cc..23c1fdfa3 100644 --- a/src/viser/client/src/components/Checkbox.tsx +++ b/src/viser/client/src/components/Checkbox.tsx @@ -5,7 +5,14 @@ import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { GuiAddCheckboxMessage } from "../WebsocketMessages"; import { Box, Checkbox, Tooltip, useMantineTheme } from "@mantine/core"; -export default function CheckboxComponent({ id, disabled, visible, hint, label, value }: GuiAddCheckboxMessage) { +export default function CheckboxComponent({ + id, + disabled, + visible, + hint, + label, + value, +}: GuiAddCheckboxMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; const theme = useMantineTheme(); @@ -13,37 +20,39 @@ export default function CheckboxComponent({ id, disabled, visible, hint, label, computeRelativeLuminance(theme.fn.primaryColor()) > 50.0 ? theme.colors.gray[9] : theme.white; - let input = { - setValue(id, value.target.checked); - }} - disabled={disabled} - styles={{ - icon: { - color: inputColor + " !important", - }, - }} - /> + let input = ( + { + setValue(id, value.target.checked); + }} + disabled={disabled} + styles={{ + icon: { + color: inputColor + " !important", + }, + }} + /> + ); if (hint !== null && hint !== undefined) { // For checkboxes, we want to make sure that the wrapper // doesn't expand to the full width of the parent. This will // de-center the tooltip. - input = - - {input} - - + input = ( + + {input} + + ); } return {input}; -} \ No newline at end of file +} diff --git a/src/viser/client/src/components/Dropdown.tsx b/src/viser/client/src/components/Dropdown.tsx index 121730b4c..72b488da5 100644 --- a/src/viser/client/src/components/Dropdown.tsx +++ b/src/viser/client/src/components/Dropdown.tsx @@ -4,33 +4,42 @@ import { ViserInputComponent } from "./common"; import { GuiAddDropdownMessage } from "../WebsocketMessages"; import { Select } from "@mantine/core"; - -export default function DropdownComponent({ id, hint, label, value, disabled, visible, options }: GuiAddDropdownMessage) { +export default function DropdownComponent({ + id, + hint, + label, + value, + disabled, + visible, + options, +}: GuiAddDropdownMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; - return - setValue(id, value)} - disabled={disabled} - searchable - maxDropdownHeight={400} - size="xs" - styles={{ - input: { - padding: "0.5em", - letterSpacing: "-0.5px", - minHeight: "1.625rem", - height: "1.625rem", - }, - }} - // zIndex of dropdown should be >modal zIndex. - // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. - zIndex={1000} - withinPortal - /> - ; -} \ No newline at end of file + return ( + + setValue(id, value)} + disabled={disabled} + searchable + maxDropdownHeight={400} + size="xs" + styles={{ + input: { + padding: "0.5em", + letterSpacing: "-0.5px", + minHeight: "1.625rem", + height: "1.625rem", + }, + }} + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + zIndex={1000} + withinPortal + /> + + ); +} diff --git a/src/viser/client/src/components/Folder.tsx b/src/viser/client/src/components/Folder.tsx index 3620f60ac..1b555927d 100644 --- a/src/viser/client/src/components/Folder.tsx +++ b/src/viser/client/src/components/Folder.tsx @@ -6,7 +6,6 @@ import { Box, Collapse, Paper } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { ViewerContext } from "../App"; - export default function FolderComponent({ id, label, @@ -15,9 +14,7 @@ export default function FolderComponent({ }: GuiAddFolderMessage) { const viewer = React.useContext(ViewerContext)!; const [opened, { toggle }] = useDisclosure(expand_by_default); - const guiIdSet = viewer.useGui( - (state) => state.guiIdSetFromContainerId[id], - ); + const guiIdSet = viewer.useGui((state) => state.guiIdSetFromContainerId[id]); const guiContext = React.useContext(GuiComponentContext)!; const isEmpty = guiIdSet === undefined || Object.keys(guiIdSet).length === 0; @@ -62,10 +59,12 @@ export default function FolderComponent({ /> - + @@ -74,4 +73,4 @@ export default function FolderComponent({ ); -} \ No newline at end of file +} diff --git a/src/viser/client/src/components/Markdown.tsx b/src/viser/client/src/components/Markdown.tsx index da1e79682..752937c16 100644 --- a/src/viser/client/src/components/Markdown.tsx +++ b/src/viser/client/src/components/Markdown.tsx @@ -3,8 +3,10 @@ import Markdown from "../Markdown"; import { ErrorBoundary } from "react-error-boundary"; import { GuiAddMarkdownMessage } from "../WebsocketMessages"; - -export default function MarkdownComponent({ visible, markdown }: GuiAddMarkdownMessage) { +export default function MarkdownComponent({ + visible, + markdown, +}: GuiAddMarkdownMessage) { if (!visible) return <>>; return ( @@ -15,4 +17,4 @@ export default function MarkdownComponent({ visible, markdown }: GuiAddMarkdownM ); -} \ No newline at end of file +} diff --git a/src/viser/client/src/components/NumberInput.tsx b/src/viser/client/src/components/NumberInput.tsx index 5056601b9..8063dade1 100644 --- a/src/viser/client/src/components/NumberInput.tsx +++ b/src/viser/client/src/components/NumberInput.tsx @@ -4,33 +4,42 @@ import { GuiAddNumberMessage } from "../WebsocketMessages"; import { ViserInputComponent } from "./common"; import { NumberInput } from "@mantine/core"; - -export default function NumberInputComponent({ visible, id, label, hint, value, disabled, ...otherProps }: GuiAddNumberMessage) { +export default function NumberInputComponent({ + visible, + id, + label, + hint, + value, + disabled, + ...otherProps +}: GuiAddNumberMessage) { const { setValue } = React.useContext(GuiComponentContext)!; const { precision, min, max, step } = otherProps; if (!visible) return <>>; - return - { - // Ignore empty values. - newValue !== "" && setValue(id, newValue); - }} - styles={{ - input: { - minHeight: "1.625rem", - height: "1.625rem", - }, - }} - disabled={disabled} - stepHoldDelay={500} - stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} - /> - ; -} \ No newline at end of file + return ( + + { + // Ignore empty values. + newValue !== "" && setValue(id, newValue); + }} + styles={{ + input: { + minHeight: "1.625rem", + height: "1.625rem", + }, + }} + disabled={disabled} + stepHoldDelay={500} + stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} + /> + + ); +} diff --git a/src/viser/client/src/components/Rgb.tsx b/src/viser/client/src/components/Rgb.tsx index f5c5fbbe6..fdc6a67c9 100644 --- a/src/viser/client/src/components/Rgb.tsx +++ b/src/viser/client/src/components/Rgb.tsx @@ -5,24 +5,33 @@ import { rgbToHex, hexToRgb } from "./utils"; import { ViserInputComponent } from "./common"; import { GuiAddRgbMessage } from "../WebsocketMessages"; -export default function RgbComponent({ id, label, hint, value, disabled, visible }: GuiAddRgbMessage) { +export default function RgbComponent({ + id, + label, + hint, + value, + disabled, + visible, +}: GuiAddRgbMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; - return - setValue(id, hexToRgb(v))} - format="hex" - // zIndex of dropdown should be >modal zIndex. - // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. - dropdownZIndex={1000} - withinPortal - styles={{ - input: { height: "1.625rem", minHeight: "1.625rem" }, - icon: { transform: "scale(0.8)" }, - }} - /> - ; -} \ No newline at end of file + return ( + + setValue(id, hexToRgb(v))} + format="hex" + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + dropdownZIndex={1000} + withinPortal + styles={{ + input: { height: "1.625rem", minHeight: "1.625rem" }, + icon: { transform: "scale(0.8)" }, + }} + /> + + ); +} diff --git a/src/viser/client/src/components/Rgba.tsx b/src/viser/client/src/components/Rgba.tsx index 755a5e51c..b96491b90 100644 --- a/src/viser/client/src/components/Rgba.tsx +++ b/src/viser/client/src/components/Rgba.tsx @@ -5,23 +5,32 @@ import { rgbaToHex, hexToRgba } from "./utils"; import { ViserInputComponent } from "./common"; import { GuiAddRgbaMessage } from "../WebsocketMessages"; -export default function RgbaComponent({ id, label, hint, value, disabled, visible }: GuiAddRgbaMessage) { +export default function RgbaComponent({ + id, + label, + hint, + value, + disabled, + visible, +}: GuiAddRgbaMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; - return - setValue(id, hexToRgba(v))} - format="hexa" - // zIndex of dropdown should be >modal zIndex. - // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. - dropdownZIndex={1000} - withinPortal - styles={{ - input: { height: "1.625rem", minHeight: "1.625rem" }, - }} - /> - ; -} \ No newline at end of file + return ( + + setValue(id, hexToRgba(v))} + format="hexa" + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + dropdownZIndex={1000} + withinPortal + styles={{ + input: { height: "1.625rem", minHeight: "1.625rem" }, + }} + /> + + ); +} diff --git a/src/viser/client/src/components/Slider.tsx b/src/viser/client/src/components/Slider.tsx index 02bf300c3..d0c85a78b 100644 --- a/src/viser/client/src/components/Slider.tsx +++ b/src/viser/client/src/components/Slider.tsx @@ -1,18 +1,18 @@ import React from "react"; import { GuiAddSliderMessage } from "../WebsocketMessages"; -import { - Slider, - Box, - Flex, - Text, - NumberInput, -} from "@mantine/core"; +import { Slider, Box, Flex, Text, NumberInput } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { ViserInputComponent } from "./common"; - - -export default function SliderComponent({ id, label, hint, visible, disabled, value, ...otherProps }: GuiAddSliderMessage) { +export default function SliderComponent({ + id, + label, + hint, + visible, + disabled, + value, + ...otherProps +}: GuiAddSliderMessage) { const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; const updateValue = (value: number) => setValue(id, value); @@ -85,7 +85,13 @@ export default function SliderComponent({ id, label, hint, visible, disabled, va const containerProps = {}; // if (marks?.some(x => x.label)) // containerProps = { ...containerProps, "mb": "md" }; - - input = {input} - return {input}; -} \ No newline at end of file + + input = ( + + {input} + + ); + return ( + {input} + ); +} diff --git a/src/viser/client/src/components/TabGroup.tsx b/src/viser/client/src/components/TabGroup.tsx index 11f21a3f5..cd9aeeb57 100644 --- a/src/viser/client/src/components/TabGroup.tsx +++ b/src/viser/client/src/components/TabGroup.tsx @@ -4,10 +4,9 @@ import { Tabs, TabsValue } from "@mantine/core"; import { Image } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; - -export default function TabGroupComponent({ +export default function TabGroupComponent({ tab_labels, - tab_icons_base64, + tab_icons_base64, tab_container_ids, visible, }: GuiAddTabGroupMessage) { @@ -53,4 +52,4 @@ export default function TabGroupComponent({ ))} ); -} \ No newline at end of file +} diff --git a/src/viser/client/src/components/TextInput.tsx b/src/viser/client/src/components/TextInput.tsx index 671354162..1d4002b02 100644 --- a/src/viser/client/src/components/TextInput.tsx +++ b/src/viser/client/src/components/TextInput.tsx @@ -8,21 +8,23 @@ export default function TextInputComponent(props: GuiAddTextMessage) { const { id, hint, label, value, disabled, visible } = props; const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; - return - { - setValue(id, value.target.value); - }} - styles={{ - input: { - minHeight: "1.625rem", - height: "1.625rem", - padding: "0 0.5em", - }, - }} - disabled={disabled} - /> - ; -} \ No newline at end of file + return ( + + { + setValue(id, value.target.value); + }} + styles={{ + input: { + minHeight: "1.625rem", + height: "1.625rem", + padding: "0 0.5em", + }, + }} + disabled={disabled} + /> + + ); +} diff --git a/src/viser/client/src/components/Vector2.tsx b/src/viser/client/src/components/Vector2.tsx index 1c98276c6..089d0dc4d 100644 --- a/src/viser/client/src/components/Vector2.tsx +++ b/src/viser/client/src/components/Vector2.tsx @@ -3,21 +3,31 @@ import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { GuiAddVector2Message } from "../WebsocketMessages"; import { VectorInput, ViserInputComponent } from "./common"; -export default function Vector2Component({ id, hint, label, visible, disabled, value, ...otherProps }: GuiAddVector2Message) { +export default function Vector2Component({ + id, + hint, + label, + visible, + disabled, + value, + ...otherProps +}: GuiAddVector2Message) { const { min, max, step, precision } = otherProps; const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; - return - setValue(id, value)} - min={min} - max={max} - step={step} - precision={precision} - disabled={disabled} - /> - ; -} \ No newline at end of file + return ( + + setValue(id, value)} + min={min} + max={max} + step={step} + precision={precision} + disabled={disabled} + /> + + ); +} diff --git a/src/viser/client/src/components/Vector3.tsx b/src/viser/client/src/components/Vector3.tsx index 42cb569df..4b20219f8 100644 --- a/src/viser/client/src/components/Vector3.tsx +++ b/src/viser/client/src/components/Vector3.tsx @@ -3,21 +3,31 @@ import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { GuiAddVector3Message } from "../WebsocketMessages"; import { VectorInput, ViserInputComponent } from "./common"; -export default function Vector3Component({ id, hint, label, visible, disabled, value, ...otherProps }: GuiAddVector3Message) { +export default function Vector3Component({ + id, + hint, + label, + visible, + disabled, + value, + ...otherProps +}: GuiAddVector3Message) { const { min, max, step, precision } = otherProps; const { setValue } = React.useContext(GuiComponentContext)!; if (!visible) return <>>; - return - setValue(id, value)} - min={min} - max={max} - step={step} - precision={precision} - disabled={disabled} - /> - ; -} \ No newline at end of file + return ( + + setValue(id, value)} + min={min} + max={max} + step={step} + precision={precision} + disabled={disabled} + /> + + ); +} diff --git a/src/viser/client/src/components/common.tsx b/src/viser/client/src/components/common.tsx index 5a9bfac6b..0e2f41fbd 100644 --- a/src/viser/client/src/components/common.tsx +++ b/src/viser/client/src/components/common.tsx @@ -1,14 +1,18 @@ -import * as React from 'react'; -import { - Box, - Flex, - Text, - NumberInput, - Tooltip, -} from '@mantine/core'; +import * as React from "react"; +import { Box, Flex, Text, NumberInput, Tooltip } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; -export function ViserInputComponent({ id, label, hint, children }: { id: string, children: React.ReactNode, label?: string, hint?: string | null }) { +export function ViserInputComponent({ + id, + label, + hint, + children, +}: { + id: string; + children: React.ReactNode; + label?: string; + hint?: string | null; +}) { const { folderDepth } = React.useContext(GuiComponentContext)!; if (hint !== undefined && hint !== null) { children = // We need to add for inputs that we can't assign refs to. @@ -44,7 +48,6 @@ export function ViserInputComponent({ id, label, hint, children }: { id: string, ); } - /** GUI input with a label horizontally placed to the left of it. */ function LabeledInput(props: { id: string; @@ -80,7 +83,6 @@ function LabeledInput(props: { ); } - export function VectorInput( props: | { diff --git a/src/viser/client/src/components/utils.tsx b/src/viser/client/src/components/utils.tsx index 0ec064548..49271026b 100644 --- a/src/viser/client/src/components/utils.tsx +++ b/src/viser/client/src/components/utils.tsx @@ -14,7 +14,12 @@ export function hexToRgb(hexColor: string): [number, number, number] { const b = parseInt(hex.substring(4, 6), 16); return [r, g, b]; } -export function rgbaToHex([r, g, b, a]: [number, number, number, number]): string { +export function rgbaToHex([r, g, b, a]: [ + number, + number, + number, + number, +]): string { const hexR = r.toString(16).padStart(2, "0"); const hexG = g.toString(16).padStart(2, "0"); const hexB = b.toString(16).padStart(2, "0"); @@ -29,4 +34,4 @@ export function hexToRgba(hexColor: string): [number, number, number, number] { const b = parseInt(hex.substring(4, 6), 16); const a = parseInt(hex.substring(6, 8), 16); return [r, g, b, a]; -} \ No newline at end of file +} From 3dc583bf46bfa5aa0da6065a88a5714ebfab7926 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Mon, 5 Feb 2024 23:33:15 -0800 Subject: [PATCH 05/10] ruff --- src/viser/_gui_handles.py | 2 -- src/viser/_messages.py | 3 --- src/viser/infra/_messages.py | 16 +++++++++++++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 2857b622f..3d9490c09 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -31,8 +31,6 @@ from ._icons_enum import IconName from ._message_api import _encode_image_base64 from ._messages import ( - GuiAddMarkdownMessage, - GuiAddTabGroupMessage, GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 61fab99ed..52ed5a773 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -8,16 +8,13 @@ Any, Callable, ClassVar, - Dict, Optional, Tuple, Type, TypeVar, Union, - cast, ) -import msgpack import numpy as onp import numpy.typing as onpt from typing_extensions import Literal, override diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index 2616ce88a..3849b3272 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -34,7 +34,11 @@ def _prepare_for_deserialization(value: Any, annotation: Type) -> Any: return value ellipsis_index = args.index(...) num_ellipsis = len(value) - len(args) + 2 - args = args[:(ellipsis_index - 1)] + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) + args[ellipsis_index + 1 :] + args = ( + args[: (ellipsis_index - 1)] + + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) + + args[ellipsis_index + 1 :] + ) if len(value) != len(args): warnings.warn(f"[viser] {value} does not match annotation {annotation}") @@ -76,7 +80,11 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any: return value ellipsis_index = args.index(...) num_ellipsis = len(value) - len(args) + 2 - args = args[:(ellipsis_index - 1)] + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) + args[ellipsis_index + 1 :] + args = ( + args[: (ellipsis_index - 1)] + + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) + + args[ellipsis_index + 1 :] + ) if len(value) != len(args): warnings.warn(f"[viser] {value} does not match annotation {annotation}") @@ -129,7 +137,9 @@ def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]: hints = get_type_hints_cached(cls) - mapping = {k: _prepare_for_deserialization(v, hints[k]) for k, v in mapping.items()} + mapping = { + k: _prepare_for_deserialization(v, hints[k]) for k, v in mapping.items() + } return mapping @classmethod From 74d780e6f5ee95fc0103258f98f958e30e6e1cc5 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 6 Feb 2024 01:06:05 -0800 Subject: [PATCH 06/10] Address CI errors --- src/viser/_gui_api.py | 38 +- .../client/src/components/MultiSlider.tsx | 117 +++ .../MultiSlider/MultiSlider.styles.tsx | 276 ++++++ .../components/MultiSlider/MultiSlider.tsx | 821 ++++++++++++++++++ src/viser/infra/_typescript_interface_gen.py | 4 +- 5 files changed, 1245 insertions(+), 11 deletions(-) create mode 100644 src/viser/client/src/components/MultiSlider.tsx create mode 100644 src/viser/client/src/components/MultiSlider/MultiSlider.styles.tsx create mode 100644 src/viser/client/src/components/MultiSlider/MultiSlider.tsx diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 323f1d9f4..6960b0b89 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -10,19 +10,38 @@ import time import warnings from pathlib import Path -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, - TypeVar, overload) +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + TypeVar, + overload, +) import numpy as onp from typing_extensions import Literal, LiteralString from . import _messages -from ._gui_handles import (GuiButtonGroupHandle, GuiButtonHandle, - GuiContainerProtocol, GuiDropdownHandle, GuiEvent, - GuiFolderHandle, GuiInputHandle, GuiMarkdownHandle, - GuiModalHandle, GuiTabGroupHandle, - SupportsRemoveProtocol, _GuiHandleState, - _GuiInputHandle, _make_unique_id) +from ._gui_handles import ( + GuiButtonGroupHandle, + GuiButtonHandle, + GuiContainerProtocol, + GuiDropdownHandle, + GuiEvent, + GuiFolderHandle, + GuiInputHandle, + GuiMarkdownHandle, + GuiModalHandle, + GuiTabGroupHandle, + SupportsRemoveProtocol, + _GuiHandleState, + _GuiInputHandle, + _make_unique_id, +) from ._icons import base64_from_icon from ._icons_enum import IconName from ._message_api import MessageApi, cast_vector @@ -938,7 +957,8 @@ def add_gui_multi_slider( max=max, step=step, value=initial_value, - visible=visible,disabled=disabled, + visible=visible, + disabled=disabled, fixed_endpoints=fixed_endpoints, precision=_compute_precision_digits(step), marks=tuple( diff --git a/src/viser/client/src/components/MultiSlider.tsx b/src/viser/client/src/components/MultiSlider.tsx new file mode 100644 index 000000000..498762a21 --- /dev/null +++ b/src/viser/client/src/components/MultiSlider.tsx @@ -0,0 +1,117 @@ +import React from "react"; +import { GuiAddMultiSliderMessage } from "../WebsocketMessages"; +import { Flex } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { ViserInputComponent } from "./common"; +import { MultiSlider } from "./MultiSlider/MultiSlider"; + +export default function MultiSliderComponent({ + id, + label, + hint, + visible, + disabled, + value, + ...otherProps +}: GuiAddMultiSliderMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>>; + const updateValue = (value: number[]) => setValue(id, value); + const { min, max, precision, step, marks, fixed_endpoints, min_range } = + otherProps; + const input = ( + + ({ + thumb: { + background: theme.fn.primaryColor(), + borderRadius: "0.1rem", + height: "0.75rem", + width: "0.625rem", + }, + trackContainer: { + zIndex: 3, + position: "relative", + }, + markLabel: { + transform: "translate(-50%, 0.03rem)", + fontSize: "0.6rem", + textAlign: "center", + }, + marksContainer: { + left: "0.2rem", + right: "0.2rem", + }, + markWrapper: { + position: "absolute", + top: `0.03rem`, + ...(marks === null + ? /* Shift the mark labels so they don't spill too far out the left/right when we only have min and max marks. */ + { + ":first-child": { + "div:nth-child(2)": { + transform: "translate(-0.2rem, 0.03rem)", + }, + }, + ":last-child": { + "div:nth-child(2)": { + transform: "translate(-90%, 0.03rem)", + }, + }, + } + : {}), + }, + mark: { + border: "0px solid transparent", + background: + theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2], + width: "0.42rem", + height: "0.42rem", + transform: `translateX(-50%)`, + }, + markFilled: { + background: disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.fn.primaryColor(), + }, + })} + pt="0.2em" + showLabelOnHover={false} + min={min} + max={max} + step={step ?? undefined} + precision={precision} + value={value} + onChange={updateValue} + marks={ + marks === null + ? [ + { + value: min, + label: `${parseInt(min.toFixed(6))}`, + }, + { + value: max, + label: `${parseInt(max.toFixed(6))}`, + }, + ] + : marks + } + disabled={disabled} + fixedEndpoints={fixed_endpoints} + minRange={min_range || undefined} + /> + + ); + + return ( + {input} + ); +} diff --git a/src/viser/client/src/components/MultiSlider/MultiSlider.styles.tsx b/src/viser/client/src/components/MultiSlider/MultiSlider.styles.tsx new file mode 100644 index 000000000..f10295e3e --- /dev/null +++ b/src/viser/client/src/components/MultiSlider/MultiSlider.styles.tsx @@ -0,0 +1,276 @@ +import { createStyles, rem } from "@mantine/styles"; +import { MantineColor, getSize, MantineNumberSize } from "@mantine/styles"; + +export const sizes = { + xs: rem(4), + sm: rem(6), + md: rem(8), + lg: rem(10), + xl: rem(12), +}; + +export const useSliderRootStyles = createStyles((theme) => ({ + root: { + ...theme.fn.fontStyles(), + WebkitTapHighlightColor: "transparent", + outline: 0, + display: "flex", + flexDirection: "column", + alignItems: "center", + touchAction: "none", + position: "relative", + }, +})); + +interface ThumbStyles { + color: MantineColor; + disabled: boolean; + thumbSize: number | string; +} + +export const useThumbStyles = createStyles( + (theme, { color, disabled, thumbSize }: ThumbStyles, { size }) => ({ + label: { + position: "absolute", + top: rem(-36), + backgroundColor: + theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[9], + fontSize: theme.fontSizes.xs, + color: theme.white, + padding: `calc(${theme.spacing.xs} / 2)`, + borderRadius: theme.radius.sm, + whiteSpace: "nowrap", + pointerEvents: "none", + userSelect: "none", + touchAction: "none", + }, + + thumb: { + ...theme.fn.focusStyles(), + boxSizing: "border-box", + position: "absolute", + display: "flex", + height: thumbSize + ? rem(thumbSize) + : `calc(${getSize({ sizes, size })} * 2)`, + width: thumbSize + ? rem(thumbSize) + : `calc(${getSize({ sizes, size })} * 2)`, + backgroundColor: disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.colorScheme === "dark" + ? theme.fn.themeColor(color, theme.fn.primaryShade()) + : theme.white, + border: `${rem(4)} solid ${ + disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.colorScheme === "dark" + ? theme.white + : theme.fn.themeColor(color, theme.fn.primaryShade()) + }`, + color: + theme.colorScheme === "dark" + ? theme.white + : theme.fn.themeColor(color, theme.fn.primaryShade()), + transform: "translate(-50%, -50%)", + top: "50%", + cursor: disabled ? "not-allowed" : "pointer", + borderRadius: 1000, + alignItems: "center", + justifyContent: "center", + transitionDuration: "100ms", + transitionProperty: "box-shadow, transform", + transitionTimingFunction: theme.transitionTimingFunction, + zIndex: 3, + userSelect: "none", + touchAction: "none", + }, + + dragging: { + transform: "translate(-50%, -50%) scale(1.05)", + boxShadow: theme.shadows.sm, + }, + }), +); + +interface TrackStyles { + radius: MantineNumberSize; + color: MantineColor; + disabled: boolean; + inverted: boolean; + thumbSize?: number; +} + +export const useTrackStyles = createStyles( + ( + theme, + { radius, color, disabled, inverted, thumbSize }: TrackStyles, + { size }, + ) => ({ + trackContainer: { + display: "flex", + alignItems: "center", + width: "100%", + height: `calc(${getSize({ sizes, size })} * 2)`, + cursor: "pointer", + + "&:has(~ input:disabled)": { + "&": { + pointerEvents: "none", + }, + + "& .mantine-Slider-thumb": { + display: "none", + }, + + "& .mantine-Slider-track::before": { + content: '""', + backgroundColor: inverted + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2], + }, + + "& .mantine-Slider-bar": { + backgroundColor: inverted + ? theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2] + : theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4], + }, + }, + }, + + track: { + position: "relative", + height: getSize({ sizes, size }), + width: "100%", + marginRight: thumbSize ? rem(thumbSize / 2) : getSize({ size, sizes }), + marginLeft: thumbSize ? rem(thumbSize / 2) : getSize({ size, sizes }), + + "&::before": { + content: '""', + position: "absolute", + top: 0, + bottom: 0, + borderRadius: theme.fn.radius(radius), + right: `calc(${ + thumbSize ? rem(thumbSize / 2) : getSize({ size, sizes }) + } * -1)`, + left: `calc(${ + thumbSize ? rem(thumbSize / 2) : getSize({ size, sizes }) + } * -1)`, + backgroundColor: inverted + ? disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.fn.variant({ variant: "filled", color }).background + : theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2], + zIndex: 0, + }, + }, + + bar: { + position: "absolute", + zIndex: 1, + top: 0, + bottom: 0, + backgroundColor: inverted + ? theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2] + : disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.fn.variant({ variant: "filled", color }).background, + borderRadius: theme.fn.radius(radius), + }, + }), +); + +interface MarksStyles { + color: MantineColor; + disabled: boolean; + thumbSize?: number; +} + +export const useMarksStyles = createStyles( + (theme, { color, disabled, thumbSize }: MarksStyles, { size }) => ({ + marksContainer: { + position: "absolute", + right: thumbSize ? rem(thumbSize / 2) : getSize({ sizes, size }), + left: thumbSize ? rem(thumbSize / 2) : getSize({ sizes, size }), + + "&:has(~ input:disabled)": { + "& .mantine-Slider-markFilled": { + border: `${rem(2)} solid ${ + theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2] + }`, + borderColor: + theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4], + }, + }, + }, + + markWrapper: { + position: "absolute", + top: `calc(${rem(getSize({ sizes, size }))} / 2)`, + zIndex: 2, + height: 0, + }, + + mark: { + boxSizing: "border-box", + border: `${rem(2)} solid ${ + theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2] + }`, + height: getSize({ sizes, size }), + width: getSize({ sizes, size }), + borderRadius: 1000, + transform: `translateX(calc(-${getSize({ sizes, size })} / 2))`, + backgroundColor: theme.white, + pointerEvents: "none", + }, + + markFilled: { + borderColor: disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.fn.variant({ variant: "filled", color }).background, + }, + + markLabel: { + transform: `translate(-50%, calc(${theme.spacing.xs} / 2))`, + fontSize: theme.fontSizes.sm, + color: + theme.colorScheme === "dark" + ? theme.colors.dark[2] + : theme.colors.gray[6], + whiteSpace: "nowrap", + cursor: "pointer", + userSelect: "none", + }, + }), +); diff --git a/src/viser/client/src/components/MultiSlider/MultiSlider.tsx b/src/viser/client/src/components/MultiSlider/MultiSlider.tsx new file mode 100644 index 000000000..e093d5fe4 --- /dev/null +++ b/src/viser/client/src/components/MultiSlider/MultiSlider.tsx @@ -0,0 +1,821 @@ +import React, { useRef, useState, forwardRef, useEffect } from "react"; +import { useMove, useUncontrolled } from "@mantine/hooks"; +import { + DefaultProps, + MantineNumberSize, + MantineColor, + useMantineTheme, + useComponentDefaultProps, + Selectors, +} from "@mantine/styles"; +import { MantineTransition, Box, Transition } from "@mantine/core"; +import { + useSliderRootStyles, + useThumbStyles, + useTrackStyles, + useMarksStyles, +} from "./MultiSlider.styles"; + +function getClientPosition(event: any) { + if ("TouchEvent" in window && event instanceof window.TouchEvent) { + const touch = event.touches[0]; + return touch.clientX; + } + + return event.clientX; +} + +interface GetPosition { + value: number; + min: number; + max: number; +} + +function getPosition({ value, min, max }: GetPosition) { + const position = ((value - min) / (max - min)) * 100; + return Math.min(Math.max(position, 0), 100); +} + +interface GetChangeValue { + value: number; + containerWidth?: number; + min: number; + max: number; + step: number; + precision?: number; +} + +function getChangeValue({ + value, + containerWidth, + min, + max, + step, + precision, +}: GetChangeValue) { + const left = !containerWidth + ? value + : Math.min(Math.max(value, 0), containerWidth) / containerWidth; + const dx = left * (max - min); + const nextValue = (dx !== 0 ? Math.round(dx / step) * step : 0) + min; + + const nextValueWithinStep = Math.max(nextValue, min); + + if (precision !== undefined) { + return Number(nextValueWithinStep.toFixed(precision)); + } + + return nextValueWithinStep; +} + +export type SliderRootStylesNames = Selectors; + +export interface SliderRootProps + extends DefaultProps, + React.ComponentPropsWithoutRef<"div"> { + size: MantineNumberSize; + children: React.ReactNode; + disabled: boolean; + variant: string; +} + +export const SliderRoot = forwardRef( + ( + { + className, + size, + classNames, + styles, + disabled, // eslint-disable-line @typescript-eslint/no-unused-vars + unstyled, + variant, + ...others + }: SliderRootProps, + ref, + ) => { + const { classes, cx } = useSliderRootStyles(null as unknown as void, { + name: "Slider", + classNames, + styles, + unstyled, + variant, + size, + }); + return ( + + ); + }, +); + +SliderRoot.displayName = "@mantine/core/SliderRoot"; + +export type ThumbStylesNames = Selectors; + +export interface ThumbProps extends DefaultProps { + max: number; + min: number; + value: number; + position: number; + dragging: boolean; + clicked: boolean; + color: MantineColor; + size: MantineNumberSize; + label: React.ReactNode; + onKeyDownCapture?(event: React.KeyboardEvent): void; + onMouseDown?( + event: React.MouseEvent | React.TouchEvent, + ): void; + labelTransition?: MantineTransition; + labelTransitionDuration?: number; + labelTransitionTimingFunction?: string; + labelAlwaysOn: boolean; + thumbLabel: string; + onFocus?(): void; + onBlur?(): void; + showLabelOnHover?: boolean; + isHovered?: boolean; + children?: React.ReactNode; + disabled: boolean; + thumbSize: number; + variant: string; +} + +export const Thumb = forwardRef( + ( + { + max, + min, + value, + position, + label, + dragging, + clicked, + onMouseDown, + onKeyDownCapture, + color, + classNames, + styles, + size, + labelTransition, + labelTransitionDuration, + labelTransitionTimingFunction, + labelAlwaysOn, + thumbLabel, + onFocus, + onBlur, + showLabelOnHover, + isHovered, + children = null, + disabled, + unstyled, + thumbSize, + variant, + }: ThumbProps, + ref, + ) => { + const { classes, cx, theme } = useThumbStyles( + { color, disabled, thumbSize }, + { name: "Slider", classNames, styles, unstyled, variant, size }, + ); + const [focused, setFocused] = useState(false); + + const isVisible = + labelAlwaysOn || dragging || focused || (showLabelOnHover && isHovered); + + return ( + + tabIndex={0} + role="slider" + aria-label={thumbLabel} + aria-valuemax={max} + aria-valuemin={min} + aria-valuenow={value} + ref={ref} + className={cx(classes.thumb, { [classes.dragging]: dragging })} + onFocus={() => { + setFocused(true); + typeof onFocus === "function" && onFocus(); + }} + onBlur={() => { + setFocused(false); + typeof onBlur === "function" && onBlur(); + }} + onTouchStart={onMouseDown} + onMouseDown={onMouseDown} + onKeyDownCapture={onKeyDownCapture} + onClick={(event) => event.stopPropagation()} + style={{ + [theme.dir === "rtl" ? "right" : "left"]: `${position}%`, + zIndex: clicked ? 1000 : undefined, + }} + > + {children} + + {(transitionStyles) => ( + + {label} + + )} + + + ); + }, +); + +Thumb.displayName = "@mantine/core/SliderThumb"; + +export type MarksStylesNames = Selectors; + +export interface MarksProps extends DefaultProps { + marks: { value: number; label?: React.ReactNode }[]; + size: MantineNumberSize; + thumbSize?: number; + color: MantineColor; + min: number; + max: number; + onChange(value: number): void; + disabled: boolean; + variant: string; +} + +export function Marks({ + marks, + color, + size, + thumbSize, + min, + max, + classNames, + styles, + onChange, + disabled, + unstyled, + variant, +}: MarksProps) { + const { classes, cx } = useMarksStyles( + { color, disabled, thumbSize }, + { name: "Slider", classNames, styles, unstyled, variant, size }, + ); + + const items = marks.map((mark, index) => ( + + + {mark.label && ( + { + event.stopPropagation(); + !disabled && onChange(mark.value); + }} + onTouchStart={(event) => { + event.stopPropagation(); + !disabled && onChange(mark.value); + }} + > + {mark.label} + + )} + + )); + + return {items}; +} + +Marks.displayName = "@mantine/core/SliderMarks"; + +export type TrackStylesNames = + | Selectors + | MarksStylesNames; + +export interface TrackProps extends DefaultProps { + marks: { value: number; label?: React.ReactNode }[]; + size: MantineNumberSize; + thumbSize?: number; + radius: MantineNumberSize; + color: MantineColor; + min: number; + max: number; + children: React.ReactNode; + onChange(value: number): void; + disabled: boolean; + variant: string; + containerProps?: React.PropsWithRef>; +} + +export function Track({ + size, + thumbSize, + color, + classNames, + styles, + radius, + children, + disabled, + unstyled, + variant, + containerProps, + ...others +}: TrackProps) { + const { classes } = useTrackStyles( + { color, radius, disabled, inverted: false }, + { name: "Slider", classNames, styles, unstyled, variant, size }, + ); + + return ( + <> + + {children} + + + + > + ); +} + +Track.displayName = "@mantine/core/SliderTrack"; + +export type MultiSliderStylesNames = + | SliderRootStylesNames + | ThumbStylesNames + | TrackStylesNames + | MarksStylesNames; + +type Value = number[]; + +export interface MultiSliderProps + extends DefaultProps, + Omit< + React.ComponentPropsWithoutRef<"div">, + "value" | "onChange" | "defaultValue" + > { + variant?: string; + + /** Color from theme.colors */ + color?: MantineColor; + + /** Key of theme.radius or any valid CSS value to set border-radius, theme.defaultRadius by default */ + radius?: MantineNumberSize; + + /** Predefined track and thumb size, number to set sizes */ + size?: MantineNumberSize; + + /** Minimal possible value */ + min?: number; + + /** Maximum possible value */ + max: number; + + /** Minimal range interval */ + minRange?: number; + + /** Number by which value will be incremented/decremented with thumb drag and arrows */ + step?: number; + + /** Amount of digits after the decimal point */ + precision?: number; + + /** Current value for controlled slider */ + value?: Value; + + /** Default value for uncontrolled slider */ + defaultValue?: Value; + + /** Called each time value changes */ + onChange?(value: Value): void; + + /** Called when user stops dragging slider or changes value with arrows */ + onChangeEnd?(value: Value): void; + + /** Hidden input name, use with uncontrolled variant */ + name?: string; + + /** Marks which will be placed on the track */ + marks: { value: number; label?: React.ReactNode }[]; + + /** Function to generate label or any react node to render instead, set to null to disable label */ + label?: React.ReactNode | ((value: number) => React.ReactNode); + + /** Label appear/disappear transition */ + labelTransition?: MantineTransition; + + /** Label appear/disappear transition duration in ms */ + labelTransitionDuration?: number; + + /** Label appear/disappear transition timing function, defaults to theme.transitionRimingFunction */ + labelTransitionTimingFunction?: string; + + /** If true label will be not be hidden when user stops dragging */ + labelAlwaysOn?: boolean; + + /** Thumb aria-label */ + thumbLabels?: string[]; + + /**If true slider label will appear on hover */ + showLabelOnHover?: boolean; + + /** Thumbs children, can be used to add icons */ + thumbChildren?: React.ReactNode | React.ReactNode[] | null; + + /** Disables slider */ + disabled?: boolean; + + /** Thumb width and height */ + thumbSize?: number; + + /** A transformation function, to change the scale of the slider */ + scale?: (value: number) => number; + + fixedEndpoints: boolean; +} + +const defaultProps: Partial = { + size: "md", + radius: "xl", + min: 0, + max: 100, + step: 1, + marks: [], + label: (f) => f, + labelTransition: "skew-down", + labelTransitionDuration: 0, + labelAlwaysOn: false, + thumbChildren: null, + showLabelOnHover: true, + disabled: false, + scale: (v) => v, + fixedEndpoints: false, +}; + +export const MultiSlider = forwardRef( + (props, ref) => { + const { + classNames, + styles, + color, + value, + onChange, + onChangeEnd, + size, + radius, + min, + max, + minRange, + step, + precision, + defaultValue, + name, + marks, + label, + labelTransition, + labelTransitionDuration, + labelTransitionTimingFunction, + labelAlwaysOn, + thumbLabels, + showLabelOnHover, + thumbChildren, + disabled, + unstyled, + thumbSize, + scale, + variant, + fixedEndpoints, + ...others + } = useComponentDefaultProps("MultiSlider", defaultProps, props) as any; + const _minRange = minRange || step; + + const theme = useMantineTheme(); + const [focused, setFocused] = useState(-1); + const [hovered, setHovered] = useState(false); + const [_value, setValue] = useUncontrolled({ + value, + defaultValue, + finalValue: [min, max], + onChange, + }); + const valueRef = useRef(_value); + const thumbs = useRef<(HTMLDivElement | null)[]>([]); + const thumbIndex = useRef(-1); + const positions = _value.map((x) => getPosition({ value: x, min, max })); + + const _setValue = (val: Value) => { + setValue(val); + valueRef.current = val; + }; + + useEffect( + () => { + if (Array.isArray(value)) { + valueRef.current = value; + } + }, + Array.isArray(value) ? [value[0], value[1]] : [null, null], + ); + + const setRangedValue = ( + val: number, + index: number, + triggerChangeEnd: boolean, + ) => { + const clone: Value = [...valueRef.current]; + clone[index] = val; + + if (index < clone.length - 1) { + if (val > clone[index + 1] - (_minRange - 0.000000001)) { + clone[index] = Math.max(min, clone[index + 1] - _minRange); + } + + if (val > (max - (_minRange - 0.000000001) || min)) { + clone[index] = valueRef.current[index]; + } + } + + if (index > 0) { + if (val < clone[index - 1] + _minRange) { + clone[index] = Math.min(max, clone[index - 1] + _minRange); + } + } + + if (fixedEndpoints && (index === 0 || index == clone.length - 1)) { + clone[index] = valueRef.current[index]; + } + + _setValue(clone); + + if (triggerChangeEnd) { + onChangeEnd?.(valueRef.current); + } + }; + + const handleChange = (val: number) => { + if (!disabled) { + const nextValue = getChangeValue({ + value: val, + min, + max, + step, + precision, + }); + setRangedValue(nextValue, thumbIndex.current, false); + } + }; + + const { ref: container, active } = useMove( + ({ x }) => handleChange(x), + { onScrubEnd: () => onChangeEnd?.(valueRef.current) }, + theme.dir, + ); + + function handleThumbMouseDown(index: number) { + thumbIndex.current = index; + } + + const handleTrackMouseDownCapture = ( + event: + | React.MouseEvent + | React.TouchEvent, + ) => { + container.current.focus(); + const rect = container.current.getBoundingClientRect(); + const changePosition = getClientPosition(event.nativeEvent); + const changeValue = getChangeValue({ + value: changePosition - rect.left, + max, + min, + step, + containerWidth: rect.width, + }); + + const _nearestHandle = _value + .map((v) => Math.abs(v - changeValue)) + .indexOf(Math.min(..._value.map((v) => Math.abs(v - changeValue)))); + + thumbIndex.current = _nearestHandle; + }; + + const getFocusedThumbIndex = () => { + if (focused !== 1 && focused !== 0) { + setFocused(0); + return 0; + } + + return focused; + }; + + const handleTrackKeydownCapture = ( + event: React.KeyboardEvent, + ) => { + if (!disabled) { + switch (event.key) { + case "ArrowUp": { + event.preventDefault(); + const focusedIndex = getFocusedThumbIndex(); + thumbs.current[focusedIndex]?.focus(); + setRangedValue( + Math.min( + Math.max(valueRef.current[focusedIndex] + step, min), + max, + ), + focusedIndex, + true, + ); + break; + } + case "ArrowRight": { + event.preventDefault(); + const focusedIndex = getFocusedThumbIndex(); + thumbs.current[focusedIndex]?.focus(); + setRangedValue( + Math.min( + Math.max( + theme.dir === "rtl" + ? valueRef.current[focusedIndex] - step + : valueRef.current[focusedIndex] + step, + min, + ), + max, + ), + focusedIndex, + true, + ); + break; + } + + case "ArrowDown": { + event.preventDefault(); + const focusedIndex = getFocusedThumbIndex(); + thumbs.current[focusedIndex]?.focus(); + setRangedValue( + Math.min( + Math.max(valueRef.current[focusedIndex] - step, min), + max, + ), + focusedIndex, + true, + ); + break; + } + case "ArrowLeft": { + event.preventDefault(); + const focusedIndex = getFocusedThumbIndex(); + thumbs.current[focusedIndex]?.focus(); + setRangedValue( + Math.min( + Math.max( + theme.dir === "rtl" + ? valueRef.current[focusedIndex] + step + : valueRef.current[focusedIndex] - step, + min, + ), + max, + ), + focusedIndex, + true, + ); + break; + } + + default: { + break; + } + } + } + }; + + const sharedThumbProps = { + max, + min, + color, + size, + labelTransition, + labelTransitionDuration, + labelTransitionTimingFunction, + labelAlwaysOn, + onBlur: () => setFocused(-1), + classNames, + styles, + }; + + const hasArrayThumbChildren = Array.isArray(thumbChildren); + + return ( + + { + const nearestValue = + Math.abs(_value[0] - val) > Math.abs(_value[1] - val) ? 1 : 0; + const clone: Value = [..._value]; + clone[nearestValue] = val; + _setValue(clone); + }} + disabled={disabled} + unstyled={unstyled} + variant={variant} + containerProps={{ + ref: container, + onMouseEnter: showLabelOnHover ? () => setHovered(true) : undefined, + onMouseLeave: showLabelOnHover + ? () => setHovered(false) + : undefined, + onTouchStartCapture: handleTrackMouseDownCapture, + onTouchEndCapture: () => { + thumbIndex.current = -1; + }, + onMouseDownCapture: handleTrackMouseDownCapture, + onMouseUpCapture: () => { + thumbIndex.current = -1; + }, + onKeyDownCapture: handleTrackKeydownCapture, + }} + > + {_value.map((value, index) => ( + { + thumbs.current[index] = node; + }} + thumbLabel={thumbLabels ? thumbLabels[index] : ""} + onMouseDown={() => handleThumbMouseDown(index)} + onFocus={() => setFocused(index)} + showLabelOnHover={showLabelOnHover} + isHovered={hovered} + disabled={disabled} + unstyled={unstyled} + thumbSize={thumbSize} + variant={variant} + > + {hasArrayThumbChildren ? thumbChildren[index] : thumbChildren} + + ))} + + {_value.map((value, index) => ( + + ))} + + ); + }, +); + +MultiSlider.displayName = "MultiSlider"; diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index 8bbe8dbcc..3204d9a33 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -127,8 +127,8 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: # Generate union type over all tags. for tag, cls_names in tag_map.items(): out_lines.append(f"export type {tag} = ") - for cls in cls_names: - out_lines.append(f" | {cls}") + for cls_name in cls_names: + out_lines.append(f" | {cls_name}") out_lines[-1] = out_lines[-1] + ";" interfaces = "\n".join(out_lines) + "\n" From 7814af316e12df33553bca37ae131a0c9fc19365 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 6 Feb 2024 01:56:23 -0800 Subject: [PATCH 07/10] Fix tab groups, spacing tweaks --- src/viser/_gui_api.py | 14 ++++++-- src/viser/_gui_handles.py | 5 ++- src/viser/client/src/components/Folder.tsx | 7 ++-- .../client/src/components/MultiSlider.tsx | 6 ++-- src/viser/client/src/components/Slider.tsx | 3 +- src/viser/infra/_messages.py | 32 ++++--------------- 6 files changed, 29 insertions(+), 38 deletions(-) diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 6960b0b89..f3c68a4ae 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -282,15 +282,25 @@ def add_gui_tab_group( """ tab_group_id = _make_unique_id() order = _apply_default_order(order) + + self._get_api()._queue( + _messages.GuiAddTabGroupMessage( + order=order, + id=tab_group_id, + container_id=self._get_container_id(), + tab_labels=(), + visible=visible, + tab_icons_base64=(), + tab_container_ids=(), + ) + ) return GuiTabGroupHandle( _tab_group_id=tab_group_id, _labels=[], _icons_base64=[], _tabs=[], _gui_api=self, - _container_id=self._get_container_id(), _order=order, - _visible=visible, ) def add_gui_markdown( diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 54986f91a..906439c8e 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -30,6 +30,7 @@ from ._icons_enum import IconName from ._message_api import _encode_image_base64 from ._messages import ( + GuiAddTabGroupMessage, GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, @@ -325,9 +326,7 @@ class GuiTabGroupHandle: _icons_base64: List[Optional[str]] _tabs: List[GuiTabHandle] _gui_api: GuiApi - _container_id: str # Parent. _order: float - _visible: bool @property def order(self) -> float: @@ -356,7 +355,7 @@ def remove(self) -> None: self._gui_api._get_api()._queue(GuiRemoveMessage(self._tab_group_id)) def _sync_with_client(self) -> None: - """Send a message that syncs tab state with the client.""" + """Send messages for syncing tab state with the client.""" self._gui_api._get_api()._queue( GuiUpdateMessage(self._tab_group_id, "tab_labels", tuple(self._labels)) ) diff --git a/src/viser/client/src/components/Folder.tsx b/src/viser/client/src/components/Folder.tsx index 1b555927d..40e9bc834 100644 --- a/src/viser/client/src/components/Folder.tsx +++ b/src/viser/client/src/components/Folder.tsx @@ -26,8 +26,11 @@ export default function FolderComponent({ pt="0.0625em" mx="xs" mt="xs" - mb="sm" - sx={{ position: "relative" }} + sx={{ + position: "relative", + marginBottom: "1.5em", + ":last-child": { marginBottom: "0.375em" }, + }} > + - + ); return ( diff --git a/src/viser/client/src/components/Slider.tsx b/src/viser/client/src/components/Slider.tsx index 756e336fc..4e7ca0a81 100644 --- a/src/viser/client/src/components/Slider.tsx +++ b/src/viser/client/src/components/Slider.tsx @@ -82,6 +82,7 @@ export default function SliderComponent({ }, })} pt="0.2em" + pb="0.4em" showLabelOnHover={false} min={min} max={max} @@ -131,8 +132,6 @@ export default function SliderComponent({ ); - const containerProps = {}; - return ( {input} ); diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index 3849b3272..e4c68a535 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -28,19 +28,9 @@ def _prepare_for_deserialization(value: Any, annotation: Type) -> Any: elif get_origin(annotation) is tuple: out = [] args = get_args(annotation) - if ... in args: - if len(value) < len(args) - 1: - warnings.warn(f"[viser] {value} does not match annotation {annotation}") - return value - ellipsis_index = args.index(...) - num_ellipsis = len(value) - len(args) + 2 - args = ( - args[: (ellipsis_index - 1)] - + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) - + args[ellipsis_index + 1 :] - ) - - if len(value) != len(args): + if len(args) >= 2 and args[1] == ...: + args = (args[0],) * len(value) + elif len(value) != len(args): warnings.warn(f"[viser] {value} does not match annotation {annotation}") return value @@ -74,19 +64,9 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any: out = [] args = get_args(annotation) - if ... in args: - if len(value) < len(args) - 1: - warnings.warn(f"[viser] {value} does not match annotation {annotation}") - return value - ellipsis_index = args.index(...) - num_ellipsis = len(value) - len(args) + 2 - args = ( - args[: (ellipsis_index - 1)] - + tuple(args[ellipsis_index - 1] for _ in range(num_ellipsis)) - + args[ellipsis_index + 1 :] - ) - - if len(value) != len(args): + if len(args) >= 2 and args[1] == ...: + args = (args[0],) * len(value) + elif len(value) != len(args): warnings.warn(f"[viser] {value} does not match annotation {annotation}") return value From 6c7de6d551b69659309270d9da7c5eef4b593b4b Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 6 Feb 2024 02:14:00 -0800 Subject: [PATCH 08/10] Fix performance regression (components were unnecessarily re-rendering) --- .../client/src/ControlPanel/Generated.tsx | 67 ++++++++----------- src/viser/client/src/components/Folder.tsx | 3 +- 2 files changed, 30 insertions(+), 40 deletions(-) diff --git a/src/viser/client/src/ControlPanel/Generated.tsx b/src/viser/client/src/ControlPanel/Generated.tsx index ae121cc84..5bb2bb02c 100644 --- a/src/viser/client/src/ControlPanel/Generated.tsx +++ b/src/viser/client/src/ControlPanel/Generated.tsx @@ -1,6 +1,5 @@ import { ViewerContext } from "../App"; import { makeThrottledMessageSender } from "../WebsocketFunctions"; -import { GuiConfig } from "./GuiState"; import { GuiComponentContext } from "./GuiComponentContext"; import { Box } from "@mantine/core"; @@ -21,42 +20,6 @@ import TabGroupComponent from "../components/TabGroup"; import FolderComponent from "../components/Folder"; import MultiSliderComponent from "../components/MultiSlider"; -function GuiContainer({ containerId }: { containerId: string }) { - const viewer = React.useContext(ViewerContext)!; - - const guiIdSet = - viewer.useGui((state) => state.guiIdSetFromContainerId[containerId]) ?? {}; - - // Render each GUI element in this container. - const guiIdArray = [...Object.keys(guiIdSet)]; - const guiOrderFromId = viewer!.useGui((state) => state.guiOrderFromId); - if (guiIdSet === undefined) return null; - - let guiIdOrderPairArray = guiIdArray.map((id) => ({ - id: id, - order: guiOrderFromId[id], - })); - let pb = undefined; - guiIdOrderPairArray = guiIdOrderPairArray.sort((a, b) => a.order - b.order); - const inputProps = viewer.useGui((state) => - guiIdOrderPairArray.map((pair) => state.guiConfigFromId[pair.id]), - ); - const lastProps = inputProps && inputProps[inputProps.length - 1]; - - // Done to match the old behaviour. Is it still needed? - if (lastProps !== undefined && lastProps.type === "GuiAddFolderMessage") { - pb = "0.125em"; - } - const out = ( - - {inputProps.map((conf) => ( - - ))} - - ); - return out; -} - /** Root of generated inputs. */ export default function GeneratedGuiContainer({ containerId, @@ -90,8 +53,36 @@ export default function GeneratedGuiContainer({ ); } +function GuiContainer({ containerId }: { containerId: string }) { + const viewer = React.useContext(ViewerContext)!; + + const guiIdSet = + viewer.useGui((state) => state.guiIdSetFromContainerId[containerId]) ?? {}; + + // Render each GUI element in this container. + const guiIdArray = [...Object.keys(guiIdSet)]; + const guiOrderFromId = viewer!.useGui((state) => state.guiOrderFromId); + if (guiIdSet === undefined) return null; + + let guiIdOrderPairArray = guiIdArray.map((id) => ({ + id: id, + order: guiOrderFromId[id], + })); + guiIdOrderPairArray = guiIdOrderPairArray.sort((a, b) => a.order - b.order); + const out = ( + + {guiIdOrderPairArray.map((pair) => ( + + ))} + + ); + return out; +} + /** A single generated GUI element. */ -function GeneratedInput(conf: GuiConfig) { +function GeneratedInput(props: { guiId: string }) { + const viewer = React.useContext(ViewerContext)!; + const conf = viewer.useGui((state) => state.guiConfigFromId[props.guiId]); switch (conf.type) { case "GuiAddFolderMessage": return ; diff --git a/src/viser/client/src/components/Folder.tsx b/src/viser/client/src/components/Folder.tsx index 40e9bc834..81fecd1b3 100644 --- a/src/viser/client/src/components/Folder.tsx +++ b/src/viser/client/src/components/Folder.tsx @@ -28,8 +28,7 @@ export default function FolderComponent({ mt="xs" sx={{ position: "relative", - marginBottom: "1.5em", - ":last-child": { marginBottom: "0.375em" }, + ":not(:last-child)": { marginBottom: "1em" }, }} > Date: Tue, 6 Feb 2024 02:17:42 -0800 Subject: [PATCH 09/10] ruff --- src/viser/_gui_handles.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 906439c8e..1e22ea94c 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -30,7 +30,6 @@ from ._icons_enum import IconName from ._message_api import _encode_image_base64 from ._messages import ( - GuiAddTabGroupMessage, GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, From 7922006f08b10eacd53fa54e31c10e070192e20e Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 6 Feb 2024 02:18:32 -0800 Subject: [PATCH 10/10] move multislider source --- src/viser/client/src/components/MultiSlider.tsx | 2 +- .../MultiSlider.styles.tsx | 0 .../{MultiSlider => MultiSliderPrimitive}/MultiSlider.tsx | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename src/viser/client/src/components/{MultiSlider => MultiSliderPrimitive}/MultiSlider.styles.tsx (100%) rename src/viser/client/src/components/{MultiSlider => MultiSliderPrimitive}/MultiSlider.tsx (100%) diff --git a/src/viser/client/src/components/MultiSlider.tsx b/src/viser/client/src/components/MultiSlider.tsx index 596005a5d..726a2d337 100644 --- a/src/viser/client/src/components/MultiSlider.tsx +++ b/src/viser/client/src/components/MultiSlider.tsx @@ -3,7 +3,7 @@ import { GuiAddMultiSliderMessage } from "../WebsocketMessages"; import { Box } from "@mantine/core"; import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; import { ViserInputComponent } from "./common"; -import { MultiSlider } from "./MultiSlider/MultiSlider"; +import { MultiSlider } from "./MultiSliderPrimitive/MultiSlider"; export default function MultiSliderComponent({ id, diff --git a/src/viser/client/src/components/MultiSlider/MultiSlider.styles.tsx b/src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.styles.tsx similarity index 100% rename from src/viser/client/src/components/MultiSlider/MultiSlider.styles.tsx rename to src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.styles.tsx diff --git a/src/viser/client/src/components/MultiSlider/MultiSlider.tsx b/src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.tsx similarity index 100% rename from src/viser/client/src/components/MultiSlider/MultiSlider.tsx rename to src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.tsx