From d5d003d7f92fdb56c69eff5e7d78ae45cec1b36c Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 23 Feb 2024 09:28:22 -0600 Subject: [PATCH] fix: Type fixes and require pyright (#302) Fixed remaining type issues and makes pyright a mandatory check via the pre-commit. Fixes #97 --------- Co-authored-by: Mike Bender --- .pre-commit-config.yaml | 17 ++++ README.md | 7 +- .../deephaven/plugin/matplotlib/__init__.py | 4 +- .../communication/DeephavenFigureListener.py | 12 ++- .../plot/express/data/data_generators.py | 20 +++- .../plot/express/data_mapping/DataMapping.py | 2 +- .../plot/express/data_mapping/data_mapping.py | 6 +- .../express/data_mapping/json_conversion.py | 4 +- .../deephaven_figure/DeephavenFigure.py | 95 +++++++++++-------- .../deephaven_figure/RevisionManager.py | 4 +- .../plot/express/deephaven_figure/generate.py | 34 ++++--- .../plot/express/plots/PartitionManager.py | 68 ++++++++----- .../deephaven/plot/express/plots/_layer.py | 70 +++++++++----- .../plot/express/plots/_private_utils.py | 26 +++-- .../plot/express/plots/distribution.py | 24 +---- .../src/deephaven/plot/express/plots/line.py | 4 +- .../src/deephaven/plot/express/plots/maps.py | 4 +- .../deephaven/plot/express/plots/scatter.py | 2 +- .../deephaven/plot/express/plots/subplots.py | 93 ++++++++++++------ .../preprocess/AttachedPreprocessor.py | 13 ++- .../express/preprocess/FreqPreprocessor.py | 2 +- .../express/preprocess/HistPreprocessor.py | 11 ++- .../plot/express/preprocess/Preprocessor.py | 9 +- .../preprocess/UnivariatePreprocessor.py | 4 +- .../plot/express/preprocess/preprocess.py | 5 +- .../deephaven/plot/express/shared/shared.py | 3 +- .../deephaven/ui/_internal/RenderContext.py | 26 ++++- .../ui/src/deephaven/ui/_internal/utils.py | 13 ++- .../ui/src/deephaven/ui/components/html.py | 2 +- .../ui/src/deephaven/ui/components/icon.py | 2 +- .../ui/src/deephaven/ui/elements/UITable.py | 28 +++--- .../ui/src/deephaven/ui/hooks/use_callback.py | 10 +- .../src/deephaven/ui/hooks/use_cell_data.py | 2 +- .../src/deephaven/ui/hooks/use_column_data.py | 4 +- .../ui/src/deephaven/ui/hooks/use_effect.py | 13 ++- .../ui/hooks/use_execution_context.py | 6 +- .../deephaven/ui/hooks/use_liveness_scope.py | 6 +- plugins/ui/src/deephaven/ui/hooks/use_memo.py | 10 +- plugins/ui/src/deephaven/ui/hooks/use_ref.py | 9 +- .../ui/src/deephaven/ui/hooks/use_row_data.py | 2 +- .../ui/src/deephaven/ui/hooks/use_row_list.py | 2 +- .../src/deephaven/ui/hooks/use_table_data.py | 4 +- .../deephaven/ui/hooks/use_table_listener.py | 8 +- .../ui/object_types/DashboardType.py | 6 +- plugins/ui/src/deephaven/ui/types/types.py | 8 +- ruff.toml | 5 + 46 files changed, 445 insertions(+), 264 deletions(-) create mode 100644 ruff.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a164c92d0..ef4506e65 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,3 +9,20 @@ repos: rev: 22.10.0 hooks: - id: black + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.334 + hooks: + - id: pyright + files: plugins.*\/src.*\.py + additional_dependencies: + [ + pandas, + deephaven-core, + plotly, + json-rpc, + matplotlib + ] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.2 + hooks: + - id: ruff diff --git a/README.md b/README.md index 093fb0c0e..bba85aea3 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Start by setting up the python venv and pre-commit hooks. ### Pre-commit hooks/Python formatting -Black and blacken-docs formatting is setup through a pre-commit hook. +Black and blacken-docs formatting, pyright type checking, and ruff linting is setup through a pre-commit hook. To install the pre-commit hooks, run the following commands from the root directory of this repo: ```shell @@ -56,6 +56,11 @@ pre-commit run --all-files All steps should pass. +To bypass the pre-commit hook, you can commit with the `--no-verify` flag, for example: +```shell +git commit --no-verify -m "commit message"` +``` + ### Running Python tests The above steps will also set up `tox` to run tests for the python plugins that support it. diff --git a/plugins/matplotlib/src/deephaven/plugin/matplotlib/__init__.py b/plugins/matplotlib/src/deephaven/plugin/matplotlib/__init__.py index 788617a5c..17491fe08 100644 --- a/plugins/matplotlib/src/deephaven/plugin/matplotlib/__init__.py +++ b/plugins/matplotlib/src/deephaven/plugin/matplotlib/__init__.py @@ -142,7 +142,7 @@ def __init__(self, fig, table, func, columns=None, fargs=None, **kwargs): super().__init__(fig, event_source, **kwargs) # Start the animation right away - self._start() + self._start() # type: ignore def new_frame_seq(self): """ @@ -156,7 +156,7 @@ def _step(self, update, *args): # Extends the _step() method for the Animation class. Used # to get the update information self._last_update = update - return super()._step(*args) + return super()._step(*args) # type: ignore def _draw_frame(self, framedata): data = {} diff --git a/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py b/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py index 4b26cad6e..288d8dfa1 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py +++ b/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py @@ -72,7 +72,7 @@ def _setup_listeners(self) -> None: self._handles.append(handle) self._liveness_scope.manage(handle) - def _get_figure(self) -> DeephavenFigure: + def _get_figure(self) -> DeephavenFigure | None: """ Get the current figure @@ -96,10 +96,9 @@ def _on_update( if self._connection: revision = self._revision_manager.get_revision() node.recreate_figure() + figure = self._get_figure() try: - self._connection.on_data( - *self._build_figure_message(self._get_figure(), revision) - ) + self._connection.on_data(*self._build_figure_message(figure, revision)) except RuntimeError: # trying to send data when the connection is closed, ignore pass @@ -115,7 +114,7 @@ def _handle_retrieve_figure(self) -> tuple[bytes, list[Any]]: return self._build_figure_message(self._get_figure()) def _build_figure_message( - self, figure: DeephavenFigure, revision: int | None = None + self, figure: DeephavenFigure | None, revision: int | None = None ) -> tuple[bytes, list[Any]]: """ Build a message to send to the client with the current figure. @@ -129,6 +128,9 @@ def _build_figure_message( """ exporter = self._exporter + if not figure: + raise ValueError("Figure is None") + with self._revision_manager: # if revision is None, just send the figure if revision is not None: diff --git a/plugins/plotly-express/src/deephaven/plot/express/data/data_generators.py b/plugins/plotly-express/src/deephaven/plot/express/data/data_generators.py index 5a0d5ae36..b3118a58c 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/data/data_generators.py +++ b/plugins/plotly-express/src/deephaven/plot/express/data/data_generators.py @@ -15,6 +15,20 @@ MINUTE = 60 * SECOND #: One minute in nanoseconds. +def _cast_timestamp(time: pd.Timestamp | None) -> pd.Timestamp: + """ + Casts a pd.Timestamp to be non-None. + Args: + time: the timestamp to cast + + Returns: + the timestamp + """ + if not time: + raise ValueError("pd_base_time is None") + return time + + def iris(ticking: bool = True, size: int = 300) -> Table: """ Returns a ticking version of the 1936 Iris flower dataset. @@ -55,7 +69,7 @@ def iris(ticking: bool = True, size: int = 300) -> Table: """ base_time = to_j_instant("1936-01-01T08:00:00 UTC") - pd_base_time = to_pd_timestamp(base_time) + pd_base_time = _cast_timestamp(to_pd_timestamp(base_time)) species_list: list[str] = ["setosa", "versicolor", "virginica"] col_ids = {"sepal_length": 0, "sepal_width": 1, "petal_length": 2, "petal_width": 3} @@ -153,7 +167,9 @@ def stocks(ticking: bool = True, hours_of_data: int = 1) -> Table: base_time = to_j_instant( "2018-06-01T08:00:00 ET" ) # day deephaven.io was registered - pd_base_time = to_pd_timestamp(base_time) + + pd_base_time = _cast_timestamp(to_pd_timestamp(base_time)) + sym_list = ["CAT", "DOG", "FISH", "BIRD", "LIZARD"] sym_dict = {v: i for i, v in enumerate(sym_list)} sym_weights = [95, 100, 70, 45, 35] diff --git a/plugins/plotly-express/src/deephaven/plot/express/data_mapping/DataMapping.py b/plugins/plotly-express/src/deephaven/plot/express/data_mapping/DataMapping.py index 17165638d..8ece0f772 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/data_mapping/DataMapping.py +++ b/plugins/plotly-express/src/deephaven/plot/express/data_mapping/DataMapping.py @@ -4,7 +4,7 @@ from typing import Any from deephaven.table import Table -from deephaven.plugin.object_type import Exporter +from ..exporter import Exporter from .json_conversion import json_link_mapping diff --git a/plugins/plotly-express/src/deephaven/plot/express/data_mapping/data_mapping.py b/plugins/plotly-express/src/deephaven/plot/express/data_mapping/data_mapping.py index f05bf81e3..4320294c1 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/data_mapping/data_mapping.py +++ b/plugins/plotly-express/src/deephaven/plot/express/data_mapping/data_mapping.py @@ -116,7 +116,11 @@ def add_custom_data_args( generators = [] for arg in CUSTOM_DATA_ARGS: - if arg in custom_call_args and (val := custom_call_args[arg]): + if ( + custom_call_args + and arg in custom_call_args + and (val := custom_call_args[arg]) + ): generators.append(custom_data_args_generator(arg, val)) update_generator = combined_generator(generators, fill={}) diff --git a/plugins/plotly-express/src/deephaven/plot/express/data_mapping/json_conversion.py b/plugins/plotly-express/src/deephaven/plot/express/data_mapping/json_conversion.py index 4d71d3829..1061d5c52 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/data_mapping/json_conversion.py +++ b/plugins/plotly-express/src/deephaven/plot/express/data_mapping/json_conversion.py @@ -3,7 +3,7 @@ from collections import defaultdict from itertools import count from collections.abc import Generator, Iterable -from typing import Any +from typing import Any, Mapping def json_links(i: int, _vars: Iterable[str]) -> Generator[str, None, None]: @@ -24,7 +24,7 @@ def json_links(i: int, _vars: Iterable[str]) -> Generator[str, None, None]: def convert_to_json_links( var_col_dicts: list[dict[str, str]], start_index: int -) -> Generator[dict[str, str], None, None]: +) -> Generator[Mapping[str, list[str]], None, None]: """Convert the provided dictionaries to json links Args: diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/DeephavenFigure.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/DeephavenFigure.py index 55082049e..3166bfdc0 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/DeephavenFigure.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/DeephavenFigure.py @@ -8,7 +8,7 @@ from copy import copy from deephaven.table import PartitionedTable, Table -from deephaven.execution_context import ExecutionContext +from deephaven.execution_context import ExecutionContext, get_exec_ctx from deephaven.liveness_scope import LivenessScope from ..shared import args_copy @@ -61,8 +61,12 @@ class DeephavenNode: represents a node in the graph. """ + def __init__(self): + self.cached_figure = None + self.parent = None + @abstractmethod - def recreate_figure(self) -> None: + def recreate_figure(self, update_parent: bool = True) -> None: """ Recreate the figure. This is called when an underlying partition or child figure changes. @@ -132,10 +136,10 @@ def __init__( func: The function to call """ self.parent = parent - self.exec_ctx = exec_ctx - self.args = args + self.exec_ctx = exec_ctx if exec_ctx else get_exec_ctx() + self.args = args if args else {} self.table = table - self.func = func + self.func = func if func else lambda **kwargs: None self.cached_figure = None self.revision_manager = RevisionManager() @@ -161,7 +165,7 @@ def recreate_figure(self, update_parent: bool = True) -> None: if self.revision_manager.updated_revision(revision): self.cached_figure = new_figure - if update_parent: + if update_parent and self.parent: self.parent.recreate_figure() def copy( @@ -195,7 +199,7 @@ def copy( new_node.parent = parent return new_node - def get_figure(self) -> DeephavenFigure: + def get_figure(self) -> DeephavenFigure | None: """ Get the figure for this node. It will be generated if not cached @@ -223,7 +227,12 @@ class DeephavenLayerNode(DeephavenNode): """ def __init__( - self, layer_func: Callable, args: dict[str, Any], exec_ctx: ExecutionContext + self, + layer_func: Callable, + args: dict[str, Any], + exec_ctx: ExecutionContext, + cached_figure: DeephavenFigure | None = None, + parent: DeephavenLayerNode | DeephavenHeadNode | None = None, ): """ Create a new DeephavenLayerNode @@ -233,15 +242,15 @@ def __init__( args: The arguments to the function exec_ctx: The execution context """ - self.parent = None + self.parent = parent self.nodes = [] self.layer_func = layer_func self.args = args - self.cached_figure = None + self.cached_figure = cached_figure self.exec_ctx = exec_ctx self.revision_manager = RevisionManager() - def recreate_figure(self, update_parent=True) -> None: + def recreate_figure(self, update_parent: bool = True) -> None: """ Recreate the figure. This is called when the underlying partition or a child node changes @@ -262,12 +271,12 @@ def recreate_figure(self, update_parent=True) -> None: if self.revision_manager.updated_revision(revision): self.cached_figure = new_figure - if update_parent: + if update_parent and self.parent: self.parent.recreate_figure() def copy( self, - parent: DeephavenNode | DeephavenHeadNode, + parent: DeephavenLayerNode | DeephavenHeadNode, partitioned_tables: dict[int, tuple[PartitionedTable, DeephavenNode]], ) -> DeephavenLayerNode: """ @@ -282,15 +291,15 @@ def copy( Returns: DeephavenLayerNode: The new node """ - new_node = DeephavenLayerNode(self.layer_func, self.args, self.exec_ctx) + new_node = DeephavenLayerNode( + self.layer_func, self.args, self.exec_ctx, self.cached_figure, parent + ) new_node.nodes = [ node.copy(new_node, partitioned_tables) for node in self.nodes ] - new_node.cached_figure = self.cached_figure - new_node.parent = parent return new_node - def get_figure(self) -> DeephavenFigure: + def get_figure(self) -> DeephavenFigure | None: """ Get the figure for this node. It will be generated if not cached @@ -322,7 +331,7 @@ def __init__(self): Create a new DeephavenHeadNode """ # there is only one child node of the head, either a layer or a figure - self.node = None + self.node: DeephavenNode | None = None self.partitioned_tables = {} self.cached_figure = None @@ -335,7 +344,8 @@ def copy_graph(self) -> DeephavenHeadNode: """ new_head = DeephavenHeadNode() new_partitioned_tables = copy(self.partitioned_tables) - new_head.node = self.node.copy(new_head, new_partitioned_tables) + if self.node: + new_head.node = self.node.copy(new_head, new_partitioned_tables) new_head.partitioned_tables = new_partitioned_tables return new_head @@ -344,10 +354,11 @@ def recreate_figure(self) -> None: Recreate the figure. This is called when the underlying partition or a child node changes """ - self.node.recreate_figure(update_parent=False) - self.cached_figure = self.node.cached_figure + if self.node: + self.node.recreate_figure(update_parent=False) + self.cached_figure = self.node.cached_figure - def get_figure(self) -> DeephavenFigure: + def get_figure(self) -> DeephavenFigure | None: """ Get the figure for this node. This will be called by a communication to get the initial figure. @@ -472,7 +483,9 @@ def to_json(self: DeephavenFigure, exporter: Exporter) -> str: The DeephavenFigure as a JSON string """ - plotly = json.loads(self._plotly_fig.to_json()) + plotly = None + if self._plotly_fig and (fig_json := self._plotly_fig.to_json()) is not None: + plotly = json.loads(fig_json) mappings = self.get_json_links(exporter) deephaven = { "mappings": mappings, @@ -529,7 +542,7 @@ def add_figure_to_graph( exec_ctx: ExecutionContext, args: dict[str, Any], table: Table | PartitionedTable, - key_column_table: Table, + key_column_table: Table | None, func: Callable, ) -> None: """ @@ -564,7 +577,7 @@ def get_head_node(self) -> DeephavenHeadNode: """ return self._head_node - def get_figure(self) -> DeephavenFigure: + def get_figure(self) -> DeephavenFigure | None: """ Get the true DeephavenFigure for this figure. @@ -573,16 +586,17 @@ def get_figure(self) -> DeephavenFigure: """ return self._head_node.get_figure() - def get_plotly_fig(self) -> Figure: + def get_plotly_fig(self) -> Figure | None: """ Get the plotly figure for this figure Returns: The plotly figure """ - if not self.get_figure(): + figure = self.get_figure() + if not figure: return self._plotly_fig - return self.get_figure().get_plotly_fig() + return figure.get_plotly_fig() def get_data_mappings(self) -> list[DataMapping]: """ @@ -591,9 +605,10 @@ def get_data_mappings(self) -> list[DataMapping]: Returns: The data mappings """ - if not self.get_figure(): + figure = self.get_figure() + if not figure: return self._data_mappings - return self.get_figure().get_data_mappings() + return figure.get_data_mappings() def get_trace_generator(self) -> Generator[dict[str, Any], None, None] | None: """ @@ -602,9 +617,10 @@ def get_trace_generator(self) -> Generator[dict[str, Any], None, None] | None: Returns: The trace generator """ - if not self.get_figure(): + figure = self.get_figure() + if not figure: return self._trace_generator - return self.get_figure().get_trace_generator() + return figure.get_trace_generator() def get_has_template(self) -> bool: """ @@ -613,9 +629,10 @@ def get_has_template(self) -> bool: Returns: True if has a template, False otherwise """ - if not self.get_figure(): + figure = self.get_figure() + if not figure: return self._has_template - return self.get_figure().get_has_template() + return figure.get_has_template() def get_has_color(self) -> bool: """ @@ -624,9 +641,10 @@ def get_has_color(self) -> bool: Returns: True if has color, False otherwise """ - if not self.get_figure(): + figure = self.get_figure() + if not figure: return self._has_color - return self.get_figure().get_has_color() + return figure.get_has_color() def get_has_subplots(self) -> bool: """ @@ -635,9 +653,10 @@ def get_has_subplots(self) -> bool: Returns: True if has subplots, False otherwise """ - if not self.get_figure(): + figure = self.get_figure() + if not figure: return self._has_subplots - return self.get_figure().get_has_subplots() + return figure.get_has_subplots() def __del__(self): self._liveness_scope.release() diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/RevisionManager.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/RevisionManager.py index f2172fad8..98493a0ae 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/RevisionManager.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/RevisionManager.py @@ -1,4 +1,5 @@ import threading +from typing import Any class RevisionManager: @@ -25,7 +26,8 @@ def __init__(self): def __enter__(self): self.lock.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + self.lock.release() def get_revision(self) -> int: diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 26a4bb5bc..7b174ef9b 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -3,7 +3,7 @@ from itertools import cycle, count from collections.abc import Generator from math import floor, ceil -from typing import Any, Callable +from typing import Any, Callable, Mapping, cast, Tuple from pandas import DataFrame from plotly.graph_objects import Figure @@ -362,7 +362,7 @@ def new_axis_generator( def attached_generator( arg: str, attached_cols: list[str] -) -> Generator[tuple[str, list], None, None]: +) -> Generator[tuple[str, list | None], None, None]: """Generate key, value pairs for error bar updates. If an error column is None, then there is no error bar drawn for the corresponding trace. @@ -393,7 +393,9 @@ def update_traces( Useful if marginals have been specified, as they should be skipped """ - for trace_index, update in zip(range(0, len(fig.data), step), generator): + for trace_index, update in zip( + range(0, len(cast(Tuple, fig.data)), step), generator + ): fig.update_traces(update, selector=trace_index) @@ -465,7 +467,7 @@ def sequence_generator( ls: str | list[str], map_: dict[str | tuple[str], str] | None = None, keys: list[tuple[str]] | None = None, -) -> Generator[tuple[str, str], None, None]: +) -> Generator[tuple[str, str | list[str]], None, None]: """Loops over the provided list to update the argument provided Args: @@ -598,12 +600,12 @@ def handle_custom_args( # Only update titles if dealing with a plot that has an axis sequence # specified as this should otherwise preserve plotly express behavior - x_axis_generators = [ + x_axis_generators: list[Generator[tuple[str, Any] | dict[str, Any], None, None]] = [ base_x_axis_generator( "xaxis_sequence" in custom_call_args and custom_call_args["xaxis_sequence"] ) ] - y_axis_generators = [ + y_axis_generators: list[Generator[tuple[str, Any] | dict[str, Any], None, None]] = [ base_y_axis_generator( "yaxis_sequence" in custom_call_args and custom_call_args["yaxis_sequence"] ) @@ -680,7 +682,7 @@ def handle_custom_args( return trace_generator -def get_list_var_info(data_cols: dict[str, str | list[str]]) -> set[str]: +def get_list_var_info(data_cols: Mapping[str, str | list[str]]) -> set[str]: """Extract the variable that is a list. Args: @@ -711,7 +713,7 @@ def get_list_var_info(data_cols: dict[str, str | list[str]]) -> set[str]: def relabel_columns( - labels: dict[str, str], + labels: dict[str, str] | None, hover_mapping: list[dict[str, str]], types: set[str], current_partition: dict[str, str], @@ -795,7 +797,7 @@ def hover_text_generator( Yields: A dictionary update """ - if "finance" in types: + if isinstance(types, set) and "finance" in types: # finance has no hover text currently (besides the default) while True: yield {} @@ -821,10 +823,10 @@ def hover_text_generator( def compute_labels( hover_mapping: list[dict[str, str]], - hist_val_name: str, + hist_val_name: str | None, # hover_data - todo, dependent on arrays supported in data mappings types: set[str], - labels: dict[str, str], + labels: dict[str, str] | None, current_partition: dict[str, str], ) -> None: """Compute the labels for this chart, relabling the axis and hovertext. @@ -847,7 +849,9 @@ def compute_labels( relabel_columns(labels, hover_mapping, types, current_partition) -def calculate_hist_labels(hist_val_name: str, current_mapping: dict[str, str]) -> None: +def calculate_hist_labels( + hist_val_name: str | None, current_mapping: dict[str, str] +) -> None: """Calculate the histogram labels Args: @@ -866,7 +870,7 @@ def calculate_hist_labels(hist_val_name: str, current_mapping: dict[str, str]) - def add_axis_titles( custom_call_args: dict[str, Any], hover_mapping: list[dict[str, str]], - hist_val_name: str, + hist_val_name: str | None, ) -> None: """Add axis titles. Generally, this only applies when there is a list variable @@ -902,7 +906,7 @@ def add_axis_titles( def create_hover_and_axis_titles( custom_call_args: dict[str, Any], - data_cols: dict[str, str], + data_cols: dict[str, str | list[str]], hover_mapping: list[dict[str, str]], ) -> Generator[dict[str, Any], None, None]: """Create hover text and axis titles. There are three main behaviors. @@ -938,7 +942,7 @@ def create_hover_and_axis_titles( labels = custom_call_args.get("labels", None) hist_val_name = custom_call_args.get("hist_val_name", None) - current_partition = custom_call_args.get("current_partition", None) + current_partition = custom_call_args.get("current_partition", {}) compute_labels(hover_mapping, hist_val_name, types, labels, current_partition) diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py index d219509a9..19265fe4b 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py @@ -1,7 +1,8 @@ from __future__ import annotations +import sys from collections.abc import Generator, Callable -from typing import Any +from typing import Any, cast, Tuple, Dict import plotly.express as px from pandas import DataFrame @@ -139,8 +140,8 @@ def __init__( self, args: dict[str, Any], draw_figure: Callable, - groups: set[str], - marg_args: dict[str, Any], + groups: set[str] | None, + marg_args: dict[str, Any] | None, marg_func: Callable, ): self.by = None @@ -155,19 +156,19 @@ def __init__( self.marginal_x = args.pop("marginal_x", None) self.marginal_y = args.pop("marginal_y", None) - self.marg_args = marg_args + self.marg_args = marg_args if marg_args else {} self.attach_marginals = marg_func self.marg_color = None self.args = args - self.groups = groups + self.groups = groups if groups else set() self.preprocessor = None self.set_long_mode_variables() self.convert_table_to_long_mode() self.key_column_table = None self.partitioned_table = self.process_partitions() self.draw_figure = draw_figure - self.constituents = None + self.constituents = [] def set_long_mode_variables(self) -> None: """ @@ -228,7 +229,7 @@ def convert_table_to_long_mode( args["table"] = self.to_long_mode(table, self.cols) - def is_by(self, arg: str, map_val: str | list[str] = None) -> None: + def is_by(self, arg: str, map_val: str | list[str] | None = None) -> None: """ Given that the specific arg is a by arg, prepare the arg depending on if it is attached or not @@ -264,7 +265,7 @@ def is_by(self, arg: str, map_val: str | list[str] = None) -> None: def handle_plot_by_arg( self, arg: str, val: str | list[str] - ) -> tuple[str, str | list[str]]: + ) -> tuple[str, str | list[str] | None]: """ Handle all args that are possibly plot bys. If the "val" is none and the "by" arg is specified, @@ -315,7 +316,8 @@ def handle_plot_by_arg( elif val: self.is_by(arg, args[map_name]) elif plot_by_cols and ( - args.get("color_discrete_sequence") or "color" in self.by_vars + args.get("color_discrete_sequence") + or (self.by_vars and "color" in self.by_vars) ): # this needs to be last as setting "color" in any sense will override if not self.args["color_discrete_sequence"]: @@ -335,7 +337,9 @@ def handle_plot_by_arg( pass elif val: self.is_by(arg) - elif plot_by_cols and (args.get("size_sequence") or "size" in self.by_vars): + elif plot_by_cols and ( + args.get("size_sequence") or (self.by_vars and "size" in self.by_vars) + ): if not self.args["size_sequence"]: self.args["size_sequence"] = STYLE_DEFAULTS[arg] args["size_by"] = plot_by_cols @@ -350,7 +354,9 @@ def handle_plot_by_arg( args[f"attached_{arg}"] = args.pop(arg) elif val: self.is_by(arg, args[map_name]) - elif plot_by_cols and (args.get(seq_name) or arg in self.by_vars): + elif plot_by_cols and ( + args.get(seq_name) or (self.by_vars and arg in self.by_vars) + ): if not seq: self.args[seq_name] = STYLE_DEFAULTS[arg] args[f"{arg}_by"] = plot_by_cols @@ -415,7 +421,9 @@ def process_partitions(self) -> Table | PartitionedTable: if partition_cols: if not partitioned_table: - partitioned_table = args["table"].partition_by(list(partition_cols)) + partitioned_table = cast(Table, args["table"]).partition_by( + list(partition_cols) + ) if not self.key_column_table: self.key_column_table = partitioned_table.table.drop_columns( "__CONSTITUENT__" @@ -466,7 +474,7 @@ def build_ternary_chain(self, cols: list[str]) -> str: ternary_string += f"{self.pivot_vars['variable']} == `{col}` ? {col} : " return ternary_string - def to_long_mode(self, table: Table, cols: list[str]) -> Table: + def to_long_mode(self, table: Table, cols: list[str] | None) -> Table: """ Convert a table to long mode. This will take the name of the columns, make a new "variable" column that contains the column names, and create @@ -480,6 +488,7 @@ def to_long_mode(self, table: Table, cols: list[str]) -> Table: The table converted to long mode """ + cols = cols if cols else [] new_tables = [] for col in cols: new_tables.append( @@ -500,10 +509,12 @@ def current_partition_generator(self) -> Generator[dict[str, str], None, None]: Yields: The partition dictionary mapping column to value """ + # the table is guaranteed to be a partitioned table here + key_columns = cast(PartitionedTable, self.partitioned_table).key_columns + # sort the columns so the order is consistent + key_columns.sort() + for table in self.constituents: - # sort the columns so the order is consistent - key_columns = self.partitioned_table.key_columns - key_columns.sort() key_column_table = dhpd.to_pandas(table.select(key_columns)) key_column_tuples = get_partition_key_column_tuples( @@ -533,11 +544,15 @@ def table_partition_generator( The tuple of table and current partition """ column = self.pivot_vars["value"] if self.pivot_vars else None - tables = self.preprocessor.preprocess_partitioned_tables( - self.constituents, column - ) - for table, current_partition in zip(tables, self.current_partition_generator()): - yield table, current_partition + if self.preprocessor: + tables = self.preprocessor.preprocess_partitioned_tables( + self.constituents, column + ) + for table, current_partition in zip( + tables, self.current_partition_generator() + ): + # since this is preprocessed it will always be a tuple + yield cast(Tuple[Table, Dict[str, str]], (table, current_partition)) def partition_generator(self) -> Generator[dict[str, Any], None, None]: """ @@ -554,7 +569,7 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]: # if a tuple is returned here, it was preprocessed already so pivots aren't needed table, arg_update = table args.update(arg_update) - elif self.pivot_vars and self.pivot_vars["value"]: + elif self.pivot_vars and self.pivot_vars["value"] and self.list_var: # there is a list of variables, so replace them with the combined column args[self.list_var] = self.pivot_vars["value"] @@ -566,11 +581,12 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]: "preprocess_hist" in self.groups or "preprocess_freq" in self.groups or "preprocess_time" in self.groups - ): + ) and self.preprocessor: # still need to preprocess the base table - table, arg_update = list( - self.preprocessor.preprocess_partitioned_tables([args["table"]]) - )[0] + table, arg_update = cast( + Tuple, + [*self.preprocessor.preprocess_partitioned_tables([args["table"]])][0], + ) args["table"] = table args.update(arg_update) yield args diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/_layer.py b/plugins/plotly-express/src/deephaven/plot/express/plots/_layer.py index f25157e79..eff6f569d 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/_layer.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/_layer.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import Any, Callable +from typing import Any, Callable, cast, Tuple from plotly.graph_objs import Figure @@ -13,6 +13,18 @@ # The function layer in this file is exempt from the styleguide rule that types should not be in the # description if there is a type annotation. +from typing import TypedDict + + +class LayerSpecDict(TypedDict, total=False): + x: list[float] | None + y: list[float] | None + xaxis_update: dict[str, Any] | None + yaxis_update: dict[str, Any] | None + wipe_layout: bool | None + matched_xaxis: str | int | None + matched_yaxis: str | int | None + def normalize_position( position: float, chart_start: float, chart_range: float @@ -60,7 +72,7 @@ def get_new_positions( return new_positions -def resize_domain(obj: dict, new_domain: dict[str, list[float]]) -> None: +def resize_domain(obj: dict, new_domain: LayerSpecDict) -> None: """Resize the domain of the given object Args: @@ -70,8 +82,8 @@ def resize_domain(obj: dict, new_domain: dict[str, list[float]]) -> None: Contains keys of x and y and values of domains, such as [0,0.5] """ - new_domain_x = new_domain.get("x", None) - new_domain_y = new_domain.get("y", None) + new_domain_x = new_domain.get("x") + new_domain_y = new_domain.get("y") obj_domain_x = obj["domain"]["x"] obj_domain_y = obj["domain"]["y"] domain_update = {} @@ -89,7 +101,7 @@ def resize_domain(obj: dict, new_domain: dict[str, list[float]]) -> None: pass -def resize_xy_axis(axis: dict, new_domain: dict[str, list[float]], which: str) -> None: +def resize_xy_axis(axis: dict, new_domain: LayerSpecDict, which: str) -> None: """Resize either an x or y axis. Args: @@ -99,11 +111,11 @@ def resize_xy_axis(axis: dict, new_domain: dict[str, list[float]], which: str) - which: Either "x" or "y" """ - new_domain_x = new_domain.get("x", None) - new_domain_y = new_domain.get("y", None) + new_domain_x = new_domain.get("x") + new_domain_y = new_domain.get("y") # the existing domain is assumed to be 0, 1 if not set axis_domain = axis.get("domain", [0, 1]) - axis_position = axis.get("position", None) + axis_position = axis.get("position") axis_update = {} try: if which == "x": @@ -172,7 +184,7 @@ def reassign_attributes(axis: dict, axes_remapping: dict[str, str]) -> None: def resize_axis( - type_: str, old_axis: str, axis: dict, num: str, new_domain: dict[str, list[float]] + type_: str, old_axis: str, axis: dict, num: str, new_domain: LayerSpecDict ) -> tuple[str, str, str]: """Maps the specified axis to new_domain and returns info to help remap axes @@ -202,7 +214,7 @@ def resize_axis( return new_axis, old_axis, new_axis -def get_axis_update(spec: dict[str, Any], type_: str) -> dict[str, Any] | None: +def get_axis_update(spec: LayerSpecDict, type_: str) -> dict[str, Any] | None: """Retrieve an axis update from the spec Args: @@ -222,7 +234,7 @@ def get_axis_update(spec: dict[str, Any], type_: str) -> dict[str, Any] | None: def match_axes( type_: str, - spec: dict[str, str | bool | list[float]], + spec: LayerSpecDict, matches_axes: dict[Any, dict[int, str]], axis_indices: dict[str, int], new_trace_axis: str, @@ -248,7 +260,7 @@ def match_axes( if there is a dictionary to match to """ - match_axis_key = spec.get(f"matched_{type_}", None) + match_axis_key = spec.get(f"matched_{type_}") axis_index = axis_indices.get(type_) if match_axis_key is not None: @@ -256,11 +268,15 @@ def match_axes( match_axis_key = (match_axis_key, type_) if match_axis_key not in matches_axes: matches_axes[match_axis_key] = {} - if not matches_axes[match_axis_key].get(axis_index, None): + if ( + matches_axes[match_axis_key] + and axis_index is not None + and not matches_axes[match_axis_key].get(axis_index) + ): # this is the base axis to match to, so matches is not added - matches_axes[match_axis_key][axis_index] = new_trace_axis return {} - return {"matches": matches_axes[match_axis_key][axis_index]} + if axis_index is not None: + return {"matches": matches_axes[match_axis_key][axis_index]} return {} @@ -268,7 +284,7 @@ def match_axes( def resize_fig( fig_data: dict, fig_layout: dict, - spec: dict[str, str | bool | list[float]], + spec: LayerSpecDict, new_axes_start: dict[str, int], matches_axes: dict[Any, dict[int, str]], ) -> tuple[dict, dict]: @@ -336,7 +352,7 @@ def resize_fig( update = get_axis_update(spec, type_) new_axis, old_trace_axis, new_trace_axis = resize_axis( - type_, name, obj, num, spec + type_, name, obj, str(num), spec ) matches_update = match_axes( @@ -375,8 +391,8 @@ def resize_fig( def fig_data_and_layout( fig: Figure, i: int, - specs: list[dict[str, str | bool | list[float]]], - which_layout: int, + specs: list[LayerSpecDict] | None, + which_layout: int | None, new_axes_start: dict[str, int], matches_axes: dict[Any, dict[int, str]], ) -> tuple[tuple | dict, dict]: @@ -417,13 +433,13 @@ def fig_data_and_layout( if which_layout is None or which_layout == i: fig_layout.update(fig.to_dict()["layout"]) - return fig.data, fig_layout + return cast(Tuple, fig.data), fig_layout def atomic_layer( *figs: DeephavenFigure | Figure, which_layout: int | None = None, - specs: list[dict[str, Any]] | None = None, + specs: list[LayerSpecDict] | None = None, unsafe_update_figure: Callable = default_callback, ) -> DeephavenFigure: """ @@ -473,8 +489,13 @@ def atomic_layer( raise NotImplementedError( "Cannot currently add figure with subplots as a subplot" ) + + plotly_fig = arg.get_plotly_fig() + if plotly_fig is None: + raise ValueError("Figure does not have a plotly figure, cannot layer") + fig_data, fig_layout = fig_data_and_layout( - arg.get_plotly_fig(), + plotly_fig, i, specs, which_layout, @@ -509,7 +530,7 @@ def atomic_layer( def layer( *figs: DeephavenFigure | Figure, which_layout: int | None = None, - specs: list[dict[str, Any]] | None = None, + specs: list[LayerSpecDict] | None = None, unsafe_update_figure: Callable = default_callback, ) -> DeephavenFigure: """Layers the provided figures. Be default, the layouts are sequentially @@ -531,9 +552,6 @@ def layer( Can also specify "matched_xaxis" or "matched_yaxis" to add this figure to a match group. All figures with the same value of this group will have matching axes. - atomic: bool: (Default value = False) If True, this layer call will be - treated as an atomic part of a figure creation call, and the figure will not be updated until - This should almost certainly always be False unsafe_update_figure: Callable: An update function that takes a plotly figure as an argument and optionally returns a plotly figure. If a figure is not returned, the plotly figure passed will be assumed to be the return value. diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py b/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py index 513dfc911..f09fd065b 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py @@ -93,7 +93,7 @@ def append_suffixes(args: list[str], suffixes: list[str], sync_dict: SyncDict) - sync_dict.d[f"{arg}_{suffix}"] = sync_dict.will_pop(arg) -def apply_args_groups(args: dict[str, Any], groups: set[str]) -> None: +def apply_args_groups(args: dict[str, Any], possible_groups: set[str] | None) -> None: """Transform args depending on groups Args: @@ -101,7 +101,9 @@ def apply_args_groups(args: dict[str, Any], groups: set[str]) -> None: groups: A set of groups used to transform the args """ - groups = groups if isinstance(groups, set) else {groups} + groups: set = ( + possible_groups if isinstance(possible_groups, set) else {possible_groups} + ) sync_dict = SyncDict(args) @@ -172,7 +174,7 @@ def create_deephaven_figure( pop: list[str] | None = None, remap: dict[str, str] | None = None, px_func: Callable = lambda: None, -) -> tuple[DeephavenFigure, Table | PartitionedTable, Table, dict[str, Any]]: +) -> tuple[DeephavenFigure, Table | PartitionedTable, Table | None, dict[str, Any]]: """Process the provided args Args: @@ -390,7 +392,7 @@ def shared_marginal( def shared_violin( - is_marginal=True, + is_marginal: bool = True, **args: Any, ) -> DeephavenFigure: """ @@ -412,7 +414,7 @@ def shared_violin( return shared_marginal(is_marginal, func, groups, **args) -def shared_box(is_marginal=True, **args: Any) -> DeephavenFigure: +def shared_box(is_marginal: bool = True, **args: Any) -> DeephavenFigure: """ Create a box figure @@ -432,7 +434,7 @@ def shared_box(is_marginal=True, **args: Any) -> DeephavenFigure: return shared_marginal(is_marginal, func, groups, **args) -def shared_strip(is_marginal=True, **args: Any) -> DeephavenFigure: +def shared_strip(is_marginal: bool = True, **args: Any) -> DeephavenFigure: """ Create a strip figure @@ -452,7 +454,7 @@ def shared_strip(is_marginal=True, **args: Any) -> DeephavenFigure: return shared_marginal(is_marginal, func, groups, **args) -def shared_histogram(is_marginal=True, **args: Any) -> DeephavenFigure: +def shared_histogram(is_marginal: bool = True, **args: Any) -> DeephavenFigure: """ Create a histogram figure @@ -518,11 +520,17 @@ def create_marginal(marginal: str, args: dict[str, Any], which: str) -> Deephave } fig_marg = marginal_map[marginal](**args) - fig_marg.get_plotly_fig().update_traces(showlegend=False) + + plotly_fig_marg = fig_marg.get_plotly_fig() + + if plotly_fig_marg is None: + raise ValueError("Plotly figure is None, cannot create marginal figure") + + plotly_fig_marg.update_traces(showlegend=False) if marginal == "rug": symbol = "line-ns-open" if which == "x" else "line-ew-open" - fig_marg.get_plotly_fig().update_traces(marker_symbol=symbol, jitter=0) + plotly_fig_marg.update_traces(marker_symbol=symbol, jitter=0) return fig_marg diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py index 7804a5ba6..48e089b6c 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py @@ -379,29 +379,7 @@ def _ecdf( Returns: """ - # todo: not fully implemented - line_shape = "hv" - # rangemode = "tozero" - - args = locals() - - validate_common_args(args) - - args["color_discrete_sequence_marker"] = args.pop("color_discrete_sequence") - - args.pop("lines") - args.pop("ecdfnorm") - args.pop("ecdfmode") - - update_wrapper = partial( - unsafe_figure_update_wrapper, args.pop("unsafe_update_figure") - ) - - create_layered = partial(preprocess_and_layer, preprocess_ecdf, px.line, args) - - return update_wrapper( - create_layered("x") if x else create_layered("y", orientation="h") - ) + raise NotImplementedError("ecdf is not yet implemented") def histogram( diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/line.py b/plugins/plotly-express/src/deephaven/plot/express/plots/line.py index a4080ddc9..e225f3fae 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/line.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/line.py @@ -60,7 +60,7 @@ def line( line_shape: str = "linear", title: str | None = None, template: str | None = None, - render_mode="svg", + render_mode: str = "svg", unsafe_update_figure: Callable = default_callback, ) -> DeephavenFigure: """Returns a line chart @@ -483,7 +483,7 @@ def line_polar( log_r: bool = False, title: str | None = None, template: str | None = None, - render_mode="svg", + render_mode: str = "svg", unsafe_update_figure: Callable = default_callback, ) -> DeephavenFigure: """Returns a polar scatter chart diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/maps.py b/plugins/plotly-express/src/deephaven/plot/express/plots/maps.py index f54bf868e..bb54a0b34 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/maps.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/maps.py @@ -49,7 +49,7 @@ def scatter_geo( projection: str | None = None, scope: str | None = None, center: dict[str, float] | None = None, - fitbounds: str = False, + fitbounds: bool | str = False, basemap_visible: bool | None = None, title: str | None = None, template: str | None = None, @@ -366,7 +366,7 @@ def line_geo( projection: str | None = None, scope: str | None = None, center: dict[str, float] | None = None, - fitbounds: str = False, + fitbounds: bool | str = False, basemap_visible: bool | None = None, title: str | None = None, template: str | None = None, diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/scatter.py b/plugins/plotly-express/src/deephaven/plot/express/plots/scatter.py index 8fa27569e..c2fa9e87b 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/scatter.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/scatter.py @@ -468,7 +468,7 @@ def scatter_polar( log_r: bool = False, title: str | None = None, template: str | None = None, - render_mode="webgl", + render_mode: str = "webgl", unsafe_update_figure: Callable = default_callback, ) -> DeephavenFigure: """Returns a polar scatter chart diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/subplots.py b/plugins/plotly-express/src/deephaven/plot/express/plots/subplots.py index 9cd8258f3..1d613ae7a 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/subplots.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/subplots.py @@ -1,17 +1,29 @@ from __future__ import annotations import math -from typing import Any +from typing import Any, TypeVar, List, cast, TypedDict from plotly.graph_objs import Figure -from ._layer import layer +from ._layer import layer, LayerSpecDict from .. import DeephavenFigure - # The function make_subplots in this file is exempt from the styleguide rule that types should not be in the # description if there is a type annotation. +# generic grid that is a list of lists of anything +T = TypeVar("T") +Grid = List[List[T]] + + +class SubplotSpecDict(TypedDict, total=False): + l: float + r: float + t: float + b: float + rowspan: int + colspan: int + def get_shared_key( row: int, @@ -41,14 +53,14 @@ def get_shared_key( def get_new_specs( - specs: list[list[dict[str, int | float]]], + specs: Grid[SubplotSpecDict] | None, row_starts: list[float], row_ends: list[float], col_starts: list[float], col_ends: list[float], - shared_xaxes: str | bool, - shared_yaxes: str | bool, -) -> list[dict[str, list[float] | int]]: + shared_xaxes: str | bool | None, + shared_yaxes: str | bool | None, +) -> list[LayerSpecDict]: """Transforms the given specs and row and column lists to specs for layering Args: @@ -89,11 +101,14 @@ def get_new_specs( r = spec.get("r", 0) t = spec.get("t", 0) b = spec.get("b", 0) - rowspan = spec.get("rowspan", 1) - colspan = spec.get("colspan", 1) + rowspan: int = int(spec.get("rowspan", 1)) + colspan: int = int(spec.get("colspan", 1)) y_1 = row_ends[row + rowspan - 1] x_1 = col_ends[col + colspan - 1] - new_spec = {"x": [x_0 + l, x_1 - r], "y": [y_0 + t, y_1 - b]} + new_spec: LayerSpecDict = { + "x": [x_0 + l, x_1 - r], + "y": [y_0 + t, y_1 - b], + } if ( shared_xaxes @@ -112,9 +127,7 @@ def get_new_specs( return new_specs -def make_grid( - items: list[Any], rows: int, cols: int, fill: Any = None -) -> list[list[Any]]: +def make_grid(items: list[T], rows: int, cols: int, fill: Any = None) -> Grid[T]: """Make a grid (list of lists) out of the provided items Args: @@ -166,7 +179,7 @@ def get_domains(values: list[float], spacing: float) -> tuple[list[float], list[ scaled = [v * scale for v in values] # the first start value is just 0 since there is no spacing preceeding it - starts = [0] + starts = [0.0] # ignore the last value as it is not needed for the start of any domain for i in range(len(scaled) - 1): starts.append(starts[-1] + scaled[i] + spacing) @@ -180,20 +193,35 @@ def get_domains(values: list[float], spacing: float) -> tuple[list[float], list[ return starts, ends +def is_grid(specs: list[SubplotSpecDict] | Grid[SubplotSpecDict]) -> bool: + """Check if the given specs is a grid + + Args: + specs: + The specs to check + + Returns: + True if the specs is a grid, False otherwise + + """ + list_count = sum(isinstance(spec, list) for spec in specs) + if 0 < list_count < len(specs): + raise ValueError("Specs is a mix of lists and non-lists") + return list_count == len(specs) and list_count > 0 + + def make_subplots( *figs: Figure | DeephavenFigure, - rows: int | None = None, - cols: int | None = None, - shared_xaxes: bool | int | None = None, - shared_yaxes: bool | int | None = None, - grid: list[list[Figure | DeephavenFigure]] | None = None, + rows: int = 0, + cols: int = 0, + shared_xaxes: str | bool | None = None, + shared_yaxes: str | bool | None = None, + grid: Grid[Figure | DeephavenFigure] | None = None, horizontal_spacing: float | None = None, vertical_spacing: float | None = None, column_widths: list[float] | None = None, row_heights: list[float] | None = None, - specs: list[dict[str, int | float]] - | list[list[dict[str, int | float]]] - | None = None, + specs: list[SubplotSpecDict] | Grid[SubplotSpecDict] | None = None, ) -> DeephavenFigure: """Create subplots. Either figs and at least one of rows and cols or grid should be passed. @@ -228,7 +256,7 @@ def make_subplots( The widths of each column. Should sum to 1. row_heights: list[float] | None: (Default value = None) The heights of each row. Should sum to 1. - specs: list[dict[str, int | float]] | list[list[dict[str, int | float]]] | None: + specs: list[SubplotSpecDict] | Grid[SubplotSpecDict] | None: (Default value = None) A list or grid of dicts that contain specs. An empty dictionary represents no specs, and None represents no figure, either @@ -258,13 +286,16 @@ def make_subplots( rows, cols = len(grid), len(grid[0]) # only transform specs into a grid when dimensions of figure grid are known - if specs: - specs = ( - specs - if isinstance(specs[0], list) - else make_grid(specs, rows, cols, fill={}) - ) - specs.reverse() + spec_grid: Grid[SubplotSpecDict] | None = None + if specs and isinstance(specs, list): + if is_grid(specs): + spec_grid = cast(Grid[Any], specs) + else: + specs = cast(List[SubplotSpecDict], specs) + spec_grid = cast(Grid[Any], make_grid(specs, rows, cols, fill={})) + spec_grid.reverse() + elif specs: + raise ValueError("specs must be a list or a grid") # same defaults as plotly if horizontal_spacing is None: @@ -287,7 +318,7 @@ def make_subplots( return layer( *[fig for fig_row in grid for fig in fig_row], specs=get_new_specs( - specs, + spec_grid, row_starts, row_ends, col_starts, diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/AttachedPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/AttachedPreprocessor.py index cc749ec40..f8776a9a7 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/AttachedPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/AttachedPreprocessor.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + from .StyleManager import StyleManager from ..shared import get_unique_names @@ -9,14 +11,17 @@ class AttachedPreprocessor: such as treemap and pie. Attributes: - args: dict[str, Any]: Args used to create the plot - always_attached: dict[tuple[str, str], - tuple[dict[str, str], list[str], str]: The dict mapping the arg and column + args: Args used to create the plot + always_attached: The dict mapping the arg and column to the style map, dictionary, and new column name, to be used for AttachedProcessor when dealing with an "always_attached" plot """ - def __init__(self, args, always_attached): + def __init__( + self, + args: dict[str, Any], + always_attached: dict[tuple[str, str], tuple[dict[str, str], list[str], str]], + ): self.args = args self.always_attached = always_attached self.prepare_preprocess() diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/FreqPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/FreqPreprocessor.py index caac4526a..847d1360e 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/FreqPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/FreqPreprocessor.py @@ -21,7 +21,7 @@ def __init__(self, args: dict[str, Any]): def preprocess_partitioned_tables( self, tables: list[Table], column: str | None = None - ) -> Generator[tuple[Table, dict[str, str]], None, None]: + ) -> Generator[tuple[Table, dict[str, str | None]], None, None]: """Preprocess frequency bar params into an appropriate table This just sums each value by count diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py index 3b54febc5..a285185c8 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py @@ -120,8 +120,8 @@ def create_range_table(self) -> Table: ).view(self.names["range"]) def create_count_tables( - self, tables: list[Table], column: str = None - ) -> Generator[tuple[Table, dict[str, str]], None, None]: + self, tables: list[Table], column: str | None = None + ) -> Generator[tuple[Table, str], None, None]: """ Create count tables that aggregate up values. @@ -135,6 +135,8 @@ def create_count_tables( """ range_index, range_ = self.names["range_index"], self.names["range"] agg_func = HISTFUNC_MAP[self.histfunc] + if not self.range_table: + raise ValueError("Range table not created") for i, table in enumerate(tables): # the column needs to be temporarily renamed to avoid collisions tmp_name = f"tmp{i}" @@ -151,7 +153,7 @@ def create_count_tables( def preprocess_partitioned_tables( self, tables: list[Table], column: str | None = None - ) -> Generator[tuple[Table, dict[str, str]], None, None]: + ) -> Generator[tuple[Table, dict[str, str | None]], None, None]: """ Preprocess tables into histogram tables @@ -187,6 +189,9 @@ def preprocess_partitioned_tables( var_axis_name = self.names[self.histfunc] + if not self.range_table: + raise ValueError("Range table not created") + bin_counts = bin_counts.join(self.range_table).update_view( [ f"{bin_min} = {range_}.binMin({range_index})", diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py index b7041e40a..808c5a75f 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py @@ -56,8 +56,12 @@ def prepare_preprocess(self) -> None: self.preprocesser = TimePreprocessor(self.args) def preprocess_partitioned_tables( - self, tables: list[Table], column: str | None = None - ) -> Generator[Table, None, None]: + self, tables: list[Table] | None, column: str | None = None + ) -> Generator[ + Table | tuple[Table, dict[str, str | None]] | tuple[Table, dict[str, str]], + None, + None, + ]: """ Preprocess the passed table, depending on the type of preprocessor used @@ -68,6 +72,7 @@ def preprocess_partitioned_tables( Yields: Table: the preprocessed table """ + tables = tables or [] if self.preprocesser: yield from self.preprocesser.preprocess_partitioned_tables(tables, column) else: diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariatePreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariatePreprocessor.py index ec444219c..4c13574b3 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariatePreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariatePreprocessor.py @@ -24,8 +24,8 @@ class UnivariatePreprocessor: def __init__(self, args: dict[str, Any], pivot_vars: dict[str, str] | None = None): self.args = args self.table = args["table"] - self.var = "x" if args.get("x", None) else "y" + self.var = "x" if args.get("x") else "y" self.other_var = "y" if self.var == "x" else "x" self.args["orientation"] = "h" if self.var == "y" else "v" - self.col_val = pivot_vars["value"] if pivot_vars else args[self.var] + self.col_val: str = pivot_vars["value"] if pivot_vars else args[self.var] self.cols = self.col_val if isinstance(self.col_val, list) else [self.col_val] diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/preprocess.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/preprocess.py index 3c5758266..c1c2711cb 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/preprocess.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/preprocess.py @@ -1,10 +1,11 @@ from __future__ import annotations from deephaven import agg +from deephaven.table import Table from deephaven.updateby import cum_sum -def preprocess_ecdf(table, column): +def preprocess_ecdf(table: Table, column: str) -> Table: """ Args: @@ -39,4 +40,4 @@ def preprocess_ecdf(table, column): .ungroup([column, prob_col]) ) - return probabilities, column, prob_col + return probabilities, column, prob_col # type: ignore diff --git a/plugins/plotly-express/src/deephaven/plot/express/shared/shared.py b/plugins/plotly-express/src/deephaven/plot/express/shared/shared.py index 6baf86f01..ea42fe677 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/shared/shared.py +++ b/plugins/plotly-express/src/deephaven/plot/express/shared/shared.py @@ -8,7 +8,8 @@ def combined_generator( - generators: list[Generator[tuple | dict, None, None]], fill: Any = None + generators: list[Generator[tuple[str, Any] | dict[str, Any], None, None]], + fill: Any = None, ) -> Generator[dict, None, None]: """Combines generators into dictionary updates iteratively One yield of this combined generator yields one yield from each dictionary, diff --git a/plugins/ui/src/deephaven/ui/_internal/RenderContext.py b/plugins/ui/src/deephaven/ui/_internal/RenderContext.py index 6e697fa4f..ea4a9a4d2 100644 --- a/plugins/ui/src/deephaven/ui/_internal/RenderContext.py +++ b/plugins/ui/src/deephaven/ui/_internal/RenderContext.py @@ -2,11 +2,21 @@ import threading import logging +from typing import ( + Any, + Callable, + Optional, + TypeVar, + Union, + Generator, + Generic, + cast, + Set, +) from functools import partial -from typing import Any, Callable, Optional, TypeVar, Union, Generic from deephaven import DHError from deephaven.liveness_scope import LivenessScope -from contextlib import AbstractContextManager, contextmanager +from contextlib import contextmanager from dataclasses import dataclass logger = logging.getLogger(__name__) @@ -52,7 +62,9 @@ class ValueWithLiveness(Generic[T]): liveness_scope: Union[LivenessScope, None] -def _value_or_call(value: T | None | Callable[[], T | None]) -> ValueWithLiveness[T]: +def _value_or_call( + value: T | None | Callable[[], T | None] +) -> ValueWithLiveness[T | None]: """ Creates a wrapper around the value, or invokes a callable to hold the value and the liveness scope creates while obtaining that value. @@ -170,7 +182,7 @@ def __del__(self): scope.release() @contextmanager - def open(self) -> AbstractContextManager: + def open(self) -> Generator[RenderContext, None, None]: """ Opens this context to track hook creation, sets this context as active on this thread, and opens the liveness scope for user-created objects. @@ -261,6 +273,10 @@ def get_state(self, key: StateKey) -> Any: self.manage(wrapper.liveness_scope) else: try: + if self._top_level_scope is None: + raise RuntimeError( + "RenderContext.get_state() called when RenderContext not opened" + ) self._top_level_scope.manage(wrapper.value) except DHError: # Ignore, we just won't manage this instance @@ -337,4 +353,4 @@ def manage(self, liveness_scope: LivenessScope) -> None: liveness_scope: the new LivenessScope to track """ assert self is get_context() - self._collected_scopes.add(liveness_scope.j_scope) + self._collected_scopes.add(cast(LivenessScope, liveness_scope.j_scope)) diff --git a/plugins/ui/src/deephaven/ui/_internal/utils.py b/plugins/ui/src/deephaven/ui/_internal/utils.py index f24567fe1..abbcf5c4a 100644 --- a/plugins/ui/src/deephaven/ui/_internal/utils.py +++ b/plugins/ui/src/deephaven/ui/_internal/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any, Callable, Set, cast from inspect import signature import sys from functools import partial @@ -152,24 +152,27 @@ def wrap_callable(func: Callable) -> Callable: """ try: if sys.version_info.major == 3 and sys.version_info.minor >= 10: - sig = signature(func, eval_str=True) + sig = signature(func, eval_str=True) # type: ignore else: sig = signature(func) - max_args = 0 - kwargs_set = set() + max_args: int | None = 0 + kwargs_set: Set | None = set() for param in sig.parameters.values(): if param.kind == param.POSITIONAL_ONLY: + max_args = cast(int, max_args) max_args += 1 elif param.kind == param.POSITIONAL_OR_KEYWORD: # Don't know until runtime whether this will be passed as a positional or keyword arg + max_args = cast(int, max_args) + kwargs_set = cast(Set, kwargs_set) max_args += 1 kwargs_set.add(param.name) elif param.kind == param.VAR_POSITIONAL: - # There are no positional args after this so max can be safely set to None max_args = None elif param.kind == param.KEYWORD_ONLY: + kwargs_set = cast(Set, kwargs_set) kwargs_set.add(param.name) elif param.kind == param.VAR_KEYWORD: kwargs_set = None diff --git a/plugins/ui/src/deephaven/ui/components/html.py b/plugins/ui/src/deephaven/ui/components/html.py index 24daf20fb..8900600d5 100644 --- a/plugins/ui/src/deephaven/ui/components/html.py +++ b/plugins/ui/src/deephaven/ui/components/html.py @@ -6,7 +6,7 @@ from ..elements import BaseElement -def html_element(tag, *children, **attributes): +def html_element(tag: str, *children, **attributes): """ Create a new HTML element. Render just returns the children that are passed in. diff --git a/plugins/ui/src/deephaven/ui/components/icon.py b/plugins/ui/src/deephaven/ui/components/icon.py index babcff348..18c43c571 100644 --- a/plugins/ui/src/deephaven/ui/components/icon.py +++ b/plugins/ui/src/deephaven/ui/components/icon.py @@ -1,7 +1,7 @@ from ..elements import BaseElement -def icon(name, *children, **props): +def icon(name: str, *children, **props): """ Get a Deephaven icon by name. """ diff --git a/plugins/ui/src/deephaven/ui/elements/UITable.py b/plugins/ui/src/deephaven/ui/elements/UITable.py index 19f3fe4b2..45dd51a5c 100644 --- a/plugins/ui/src/deephaven/ui/elements/UITable.py +++ b/plugins/ui/src/deephaven/ui/elements/UITable.py @@ -1,7 +1,8 @@ from __future__ import annotations +import collections import logging -from typing import Any, Callable, Literal, Sequence, Optional +from typing import Callable, Literal, Sequence, Any, cast from deephaven.table import Table from deephaven import SortDirection from .Element import Element @@ -18,17 +19,17 @@ DataBarAxis, DataBarValuePlacement, DataBarDirection, - RowIndex, - RowDataMap, SelectionMode, TableSortDirection, + RowPressCallback, + StringSortDirection, ) from .._internal import dict_to_camel_case, RenderContext logger = logging.getLogger(__name__) -def remap_sort_direction(direction: TableSortDirection) -> Literal["ASC", "DESC"]: +def remap_sort_direction(direction: TableSortDirection | str) -> Literal["ASC", "DESC"]: """ Remap the sort direction to the grid sort direction @@ -43,7 +44,7 @@ def remap_sort_direction(direction: TableSortDirection) -> Literal["ASC", "DESC" elif direction == SortDirection.DESCENDING: return "DESC" elif direction in {"ASC", "DESC"}: - return direction + return cast(StringSortDirection, direction) raise ValueError(f"Invalid table sort direction: {direction}") @@ -394,9 +395,7 @@ def hide_columns(self, columns: str | list[str]) -> "UITable": """ raise NotImplementedError() - def on_row_press( - self, callback: Callable[[RowIndex, RowDataMap], None] - ) -> "UITable": + def on_row_press(self, callback: RowPressCallback) -> "UITable": """ Add a callback for when a press on a row is released (e.g. a row is clicked). @@ -410,9 +409,7 @@ def on_row_press( """ raise NotImplementedError() - def on_row_double_press( - self, callback: Callable[[RowIndex, RowDataMap], None] - ) -> "UITable": + def on_row_double_press(self, callback: RowPressCallback) -> "UITable": """ Add a callback for when a row is double clicked. @@ -475,13 +472,16 @@ def sort( """ direction_list: Sequence[TableSortDirection] = [] if direction: - direction_list = direction if isinstance(direction, list) else [direction] + direction_list_unmapped = ( + direction if isinstance(direction, Sequence) else [direction] + ) + # map deephaven sort direction to frontend sort direction direction_list = [ - remap_sort_direction(direction) for direction in direction_list + remap_sort_direction(direction) for direction in direction_list_unmapped ] - by_list = by if isinstance(by, list) else [by] + by_list = by if isinstance(by, Sequence) else [by] if direction and len(direction_list) != len(by_list): raise ValueError("by and direction must be the same length") diff --git a/plugins/ui/src/deephaven/ui/hooks/use_callback.py b/plugins/ui/src/deephaven/ui/hooks/use_callback.py index 653a7bc69..7496e1205 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_callback.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_callback.py @@ -1,7 +1,11 @@ -from .use_ref import use_ref +from __future__ import annotations +from typing import Callable, Any, Sequence -def use_callback(func, dependencies): +from .use_ref import use_ref, Ref + + +def use_callback(func: Callable, dependencies: set[Any] | Sequence[Any]) -> Callable: """ Create a stable handle for a callback function. The callback will only be recreated if the dependencies change. @@ -12,7 +16,7 @@ def use_callback(func, dependencies): Returns: The stable handle to the callback function. """ - deps_ref = use_ref(None) + deps_ref: Ref[set[Any] | Sequence[Any] | None] = use_ref(None) callback_ref = use_ref(lambda: None) stable_callback_ref = use_ref( lambda *args, **kwargs: callback_ref.current(*args, **kwargs) diff --git a/plugins/ui/src/deephaven/ui/hooks/use_cell_data.py b/plugins/ui/src/deephaven/ui/hooks/use_cell_data.py index 07dde4a82..967ba5120 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_cell_data.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_cell_data.py @@ -9,7 +9,7 @@ from ..types import Sentinel -def _cell_data(data: pd.DataFrame, is_sentinel: bool) -> None: +def _cell_data(data: pd.DataFrame | Sentinel, is_sentinel: bool) -> Any | Sentinel: """ Return the first cell of the table. diff --git a/plugins/ui/src/deephaven/ui/hooks/use_column_data.py b/plugins/ui/src/deephaven/ui/hooks/use_column_data.py index a29d73086..c7522437a 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_column_data.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_column_data.py @@ -8,7 +8,9 @@ from ..types import Sentinel, ColumnData -def _column_data(data: pd.DataFrame, is_sentinel: bool) -> ColumnData: +def _column_data( + data: pd.DataFrame | Sentinel, is_sentinel: bool +) -> ColumnData | Sentinel: """ Return the first column of the table as a list. diff --git a/plugins/ui/src/deephaven/ui/hooks/use_effect.py b/plugins/ui/src/deephaven/ui/hooks/use_effect.py index 335432810..713c3324e 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_effect.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_effect.py @@ -1,9 +1,12 @@ -from .use_ref import use_ref +from __future__ import annotations + +from typing import Callable, Any, cast, Sequence +from .use_ref import use_ref, Ref from deephaven.liveness_scope import LivenessScope from .._internal import get_context -def use_effect(func, dependencies): +def use_effect(func: Callable[[], Any], dependencies: set[Any] | Sequence[Any]): """ Call a function when the dependencies change. Optionally return a cleanup function to be called when dependencies change again or component is unmounted. @@ -14,9 +17,9 @@ def use_effect(func, dependencies): Returns: None """ - deps_ref = use_ref(None) + deps_ref: Ref[set[Any] | Sequence[Any] | None] = use_ref(None) cleanup_ref = use_ref(lambda: None) - scope_ref = use_ref(None) + scope_ref: Ref[LivenessScope | None] = use_ref(None) # Check if the dependencies have changed if deps_ref.current != dependencies: @@ -37,4 +40,4 @@ def use_effect(func, dependencies): deps_ref.current = dependencies # Whether new or existing, continue to retain the liveness scope from the most recently invoked effect. - get_context().manage(scope_ref.current) + get_context().manage(cast(LivenessScope, scope_ref.current)) diff --git a/plugins/ui/src/deephaven/ui/hooks/use_execution_context.py b/plugins/ui/src/deephaven/ui/hooks/use_execution_context.py index deec01f8e..93f1c9045 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_execution_context.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_execution_context.py @@ -24,7 +24,7 @@ def func_with_ctx( def use_execution_context( - exec_ctx: ExecutionContext = None, + exec_ctx: ExecutionContext | None = None, ) -> Callable[[Callable], None]: """ Create an execution context wrapper for a function. @@ -35,6 +35,6 @@ def use_execution_context( Returns: A callable that will take any callable and invoke it within the current exec context """ - exec_ctx = use_memo(lambda: exec_ctx if exec_ctx else get_exec_ctx(), [exec_ctx]) - exec_fn = use_memo(lambda: partial(func_with_ctx, exec_ctx), [exec_ctx]) + exec_ctx = use_memo(lambda: exec_ctx if exec_ctx else get_exec_ctx(), {exec_ctx}) + exec_fn = use_memo(lambda: partial(func_with_ctx, exec_ctx), {exec_ctx}) return exec_fn diff --git a/plugins/ui/src/deephaven/ui/hooks/use_liveness_scope.py b/plugins/ui/src/deephaven/ui/hooks/use_liveness_scope.py index 8cc404684..8c4233df4 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_liveness_scope.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_liveness_scope.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from .._internal import get_context -from .use_ref import use_ref +from .use_ref import use_ref, Ref from typing import Callable from deephaven.liveness_scope import LivenessScope @@ -17,7 +19,7 @@ def use_liveness_scope(func: Callable) -> Callable: Returns: The wrapped Callable """ - scope_ref = use_ref(None) + scope_ref: Ref[LivenessScope | None] = use_ref(None) # If the value is set, the wrapped callable was invoked since we were last called - the current render # cycle now will manage it, and release it when appropriate. diff --git a/plugins/ui/src/deephaven/ui/hooks/use_memo.py b/plugins/ui/src/deephaven/ui/hooks/use_memo.py index 1c7ec74aa..39b384449 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_memo.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_memo.py @@ -2,13 +2,13 @@ from .use_ref import use_ref, Ref from .._internal import ValueWithLiveness, get_context -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TypeVar, cast, Union, Sequence from deephaven.liveness_scope import LivenessScope T = TypeVar("T") -def use_memo(func: Callable[[], T], dependencies: set[Any]) -> T: +def use_memo(func: Callable[[], T], dependencies: set[Any] | Sequence[Any]) -> T: """ Memoize the result of a function call. The function will only be called again if the dependencies change. @@ -19,9 +19,9 @@ def use_memo(func: Callable[[], T], dependencies: set[Any]) -> T: Returns: The memoized result of the function call. """ - deps_ref: Ref[set[Any] | None] = use_ref(None) + deps_ref: Ref[set[Any] | Sequence[Any] | None] = use_ref(None) value_ref: Ref[ValueWithLiveness[T | None]] = use_ref( - ValueWithLiveness(value=None, liveness_scope=None) + ValueWithLiveness(value=cast(Union[T, None], None), liveness_scope=None) ) if deps_ref.current != dependencies: @@ -33,7 +33,7 @@ def use_memo(func: Callable[[], T], dependencies: set[Any]) -> T: ) # The current RenderContext will then own the newly created liveness scope, and release when appropriate. - get_context().manage(value_ref.current.liveness_scope) + get_context().manage(cast(LivenessScope, value_ref.current.liveness_scope)) deps_ref.current = dependencies elif value_ref.current.liveness_scope: diff --git a/plugins/ui/src/deephaven/ui/hooks/use_ref.py b/plugins/ui/src/deephaven/ui/hooks/use_ref.py index 70e9f55d7..0b3ff97f1 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_ref.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_ref.py @@ -1,7 +1,7 @@ from __future__ import annotations from .use_state import use_state -from typing import Generic, overload, TypeVar +from typing import Generic, overload, TypeVar, Optional T = TypeVar("T") @@ -23,7 +23,12 @@ def use_ref(initial_value: T) -> Ref[T]: ... -def use_ref(initial_value: T | None = None) -> Ref[T | None]: +@overload +def use_ref(initial_value: T | None) -> Ref[T | None]: + ... + + +def use_ref(initial_value: T | None = None) -> Ref[T] | Ref[T | None]: """ Store a reference to a value that will persist across renders. diff --git a/plugins/ui/src/deephaven/ui/hooks/use_row_data.py b/plugins/ui/src/deephaven/ui/hooks/use_row_data.py index 75d8df004..c07d33e79 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_row_data.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_row_data.py @@ -8,7 +8,7 @@ from ..types import Sentinel, RowData -def _row_data(data: pd.DataFrame, is_sentinel: bool) -> RowData: +def _row_data(data: pd.DataFrame | Sentinel, is_sentinel: bool) -> RowData | Sentinel: """ Return the first row of the table as a dictionary. diff --git a/plugins/ui/src/deephaven/ui/hooks/use_row_list.py b/plugins/ui/src/deephaven/ui/hooks/use_row_list.py index e4e363053..e303028f3 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_row_list.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_row_list.py @@ -9,7 +9,7 @@ from ..types import Sentinel -def _row_list(data: pd.DataFrame, is_sentinel: bool) -> list[Any]: +def _row_list(data: pd.DataFrame | Sentinel, is_sentinel: bool) -> list[Any] | Sentinel: """ Return the first row of the table as a list. diff --git a/plugins/ui/src/deephaven/ui/hooks/use_table_data.py b/plugins/ui/src/deephaven/ui/hooks/use_table_data.py index 5041126e3..e4dc1dce0 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_table_data.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_table_data.py @@ -90,7 +90,9 @@ def _set_new_data( set_is_sentinel(new_is_sentinel) -def _table_data(data: pd.DataFrame, is_sentinel: bool) -> TableData: +def _table_data( + data: pd.DataFrame | Sentinel, is_sentinel: bool +) -> TableData | Sentinel: """ Returns the table as a dictionary. diff --git a/plugins/ui/src/deephaven/ui/hooks/use_table_listener.py b/plugins/ui/src/deephaven/ui/hooks/use_table_listener.py index 13bd79715..a1ba05929 100644 --- a/plugins/ui/src/deephaven/ui/hooks/use_table_listener.py +++ b/plugins/ui/src/deephaven/ui/hooks/use_table_listener.py @@ -47,7 +47,7 @@ def with_ctx( def wrap_listener( listener: Callable[[TableUpdate, bool], None] | TableListener -) -> Callable[[TableUpdate, bool], None] | None: +) -> Callable[[TableUpdate, bool], None]: """ Wrap the listener in an execution context. @@ -61,7 +61,7 @@ def wrap_listener( return with_ctx(listener.on_update) elif callable(listener): return with_ctx(listener) - return None + raise ValueError("Listener must be a function or a TableListener") def use_table_listener( @@ -98,11 +98,11 @@ def start_listener() -> Callable[[], None]: handle = listen( table, wrap_listener(listener), - description=description, + description=description, # type: ignore # missing Optional type do_replay=do_replay, replay_lock=replay_lock, ) return lambda: handle.stop() - use_effect(start_listener, [table, listener, description, do_replay, replay_lock]) + use_effect(start_listener, {table, listener, description, do_replay, replay_lock}) diff --git a/plugins/ui/src/deephaven/ui/object_types/DashboardType.py b/plugins/ui/src/deephaven/ui/object_types/DashboardType.py index 371e7d5e7..622aa8cc1 100644 --- a/plugins/ui/src/deephaven/ui/object_types/DashboardType.py +++ b/plugins/ui/src/deephaven/ui/object_types/DashboardType.py @@ -1,6 +1,6 @@ +from typing import Any + from ..elements import DashboardElement -from .._internal import get_component_name -from .ElementMessageStream import ElementMessageStream from .ElementType import ElementType @@ -13,5 +13,5 @@ class DashboardType(ElementType): def name(self) -> str: return "deephaven.ui.Dashboard" - def is_type(self, obj: any) -> bool: + def is_type(self, obj: Any) -> bool: return isinstance(obj, DashboardElement) diff --git a/plugins/ui/src/deephaven/ui/types/types.py b/plugins/ui/src/deephaven/ui/types/types.py index 9aae04738..f727104f5 100644 --- a/plugins/ui/src/deephaven/ui/types/types.py +++ b/plugins/ui/src/deephaven/ui/types/types.py @@ -1,8 +1,9 @@ -from typing import Any, Dict, Literal, Union, List, Tuple, Optional +from typing import Any, Dict, Literal, Union, List, Tuple, Optional, Callable from deephaven import SortDirection RowIndex = Optional[int] RowDataMap = Dict[str, Any] +RowPressCallback = Callable[[RowIndex, RowDataMap], None] AggregationOperation = Literal[ "COUNT", "COUNT_DISTINCT", @@ -37,11 +38,12 @@ QuickFilterExpression = str RowData = Dict[ColumnName, Any] # A RowIndex of None indicates a header column -RowIndex = Union[int, None] +RowIndex = Optional[int] ColumnData = List[Any] TableData = Dict[ColumnName, ColumnData] SearchMode = Literal["SHOW", "HIDE", "DEFAULT"] SelectionMode = Literal["CELL", "ROW", "COLUMN"] Sentinel = Any TransformedData = Any -TableSortDirection = Union[Literal["ASC", "DESC"], SortDirection] +StringSortDirection = Literal["ASC", "DESC"] +TableSortDirection = Union[StringSortDirection, SortDirection] diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..96bed7376 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,5 @@ +[lint] +select = ["ANN001"] + +[lint.per-file-ignores] +"**/{test,matplotlib,json,plotly}/*" = ["ANN001"] \ No newline at end of file