From b8a512a9f4cb1df97e669950f8d4821a0172377e Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 18 Dec 2024 13:41:31 -0600 Subject: [PATCH] wip --- .../plot/express/deephaven_figure/generate.py | 57 +++++++++---------- .../plot/express/plots/PartitionManager.py | 54 +++++++++++------- .../plot/express/plots/_private_utils.py | 12 ++-- .../express/preprocess/HistPreprocessor.py | 9 +-- .../plot/express/preprocess/Preprocessor.py | 14 ++--- .../preprocess/UnivariateAwarePreprocessor.py | 45 +++++++++++---- 6 files changed, 113 insertions(+), 78 deletions(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 4e4ec36a2..8e8ab45d7 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -111,9 +111,9 @@ "current_col", "current_var", "labels", - "hist_agg_label_h", - "hist_agg_label_v", - "pivot_vars", + "hist_agg_label", + "hist_orientation", + "stacked_column_names", "current_partition", "colors", "unsafe_update_figure", @@ -685,7 +685,7 @@ def handle_custom_args( return trace_generator -def get_list_var_info(data_cols: Mapping[str, str | list[str]]) -> set[str]: +def get_list_param_info(data_cols: Mapping[str, str | list[str]]) -> set[str]: """Extract the variable that is a list. 
Args: @@ -825,8 +825,8 @@ def hover_text_generator( def compute_labels( hover_mapping: list[dict[str, str]], - hist_agg_label_h: str | None, - hist_agg_label_v: str | None, + hist_agg_label: str | None, + hist_orientation: str | None, heatmap_agg_label: str | None, # hover_data - todo, dependent on arrays supported in data mappings types: set[str], @@ -839,8 +839,8 @@ def compute_labels( Args: hover_mapping: The mapping of variables to columns - hist_agg_label_h: The histogram agg label when oriented horizontally - hist_agg_label_v: The histogram agg label when oriented vertically + hist_agg_label: The histogram agg label + hist_orientation: The histogram orientation heatmap_agg_label: The aggregate density heatmap column title types: Any types of this chart that require special processing labels: A dictionary of old column name to new column name mappings @@ -850,7 +850,7 @@ def compute_labels( the renamed current_col """ - calculate_hist_labels(hist_agg_label_h, hist_agg_label_v, hover_mapping[0]) + calculate_hist_labels(hist_agg_label, hist_orientation, hover_mapping[0]) calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0], labels) @@ -883,31 +883,30 @@ def calculate_density_heatmap_labels( def calculate_hist_labels( - hist_agg_label_h: str | None, - hist_agg_label_v: str | None, + hist_agg_label: str | None, + hist_orientation: str | None, hover_mapping: dict[str, str], ) -> None: """Calculate the histogram labels Args: - hist_agg_label_h: The histogram agg label when oriented horizontally - hist_agg_label_v: The histogram agg label when oriented vertically + hist_agg_label: The histogram agg label + hist_orientation: The histogram orientation hover_mapping: The mapping of variables to columns """ # only one should be set - if hist_agg_label_h: + if hist_orientation == "h" and hist_agg_label: # a bar chart oriented horizontally has the histfunc on the x-axis - hover_mapping["x"] = hist_agg_label_h - elif hist_agg_label_v: - hover_mapping["y"] = 
hist_agg_label_v + hover_mapping["x"] = hist_agg_label + elif hist_orientation == "v" and hist_agg_label: + hover_mapping["y"] = hist_agg_label def add_axis_titles( custom_call_args: dict[str, Any], hover_mapping: list[dict[str, str]], - hist_agg_label_h: str | None, - hist_agg_label_v: str | None, + hist_agg_label: str | None, heatmap_agg_label: str | None, ) -> None: """Add axis titles. Generally, this only applies when there is a list variable @@ -916,8 +915,7 @@ def add_axis_titles( custom_call_args: The custom_call_args that are used to create hover and axis titles hover_mapping: The mapping of variables to columns - hist_agg_label_h: The histogram agg label when oriented horizontally - hist_agg_label_v: The histogram agg label when oriented vertically + hist_agg_label: The histogram agg label heatmap_agg_label: The aggregate density heatmap column title """ @@ -927,7 +925,7 @@ def add_axis_titles( new_xaxis_titles = None new_yaxis_titles = None - if hist_agg_label_h or hist_agg_label_v: + if hist_agg_label: # hist labels are already set up in the mapping new_xaxis_titles = [hover_mapping[0].get("x", None)] new_yaxis_titles = [hover_mapping[0].get("y", None)] @@ -953,7 +951,7 @@ def create_hover_and_axis_titles( hover_mapping: list[dict[str, str]], ) -> Generator[dict[str, Any], None, None]: """Create hover text and axis titles. There are three main behaviors. - First is "current_col", "current_var", and "pivot_vars" are specified in + First is "current_col", "current_var", and "stacked_column_names" are specified in "custom_call_args". In this case, there is a list of variables, but they are layered outside the generate function. 
@@ -983,19 +981,19 @@ def create_hover_and_axis_titles( Yields: Dicts containing hover updates """ - types = get_list_var_info(data_cols) + types = get_list_param_info(data_cols) labels = custom_call_args.get("labels", None) - hist_agg_label_h = custom_call_args.get("hist_agg_label_h", None) - hist_agg_label_v = custom_call_args.get("hist_agg_label_v", None) + hist_agg_label = custom_call_args.get("hist_agg_label", None) + hist_orientation = custom_call_args.get("hist_orientation", None) heatmap_agg_label = custom_call_args.get("heatmap_agg_label", None) current_partition = custom_call_args.get("current_partition", {}) compute_labels( hover_mapping, - hist_agg_label_h, - hist_agg_label_v, + hist_agg_label, + hist_orientation, heatmap_agg_label, types, labels, @@ -1011,8 +1009,7 @@ def create_hover_and_axis_titles( add_axis_titles( custom_call_args, hover_mapping, - hist_agg_label_h, - hist_agg_label_v, + hist_agg_label, heatmap_agg_label, ) diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py index cc0b33cce..5a18b576a 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py @@ -9,7 +9,7 @@ from deephaven.table import Table, PartitionedTable from deephaven import pandas as dhpd -from deephaven import merge, empty_table +from deephaven import merge from ._layer import atomic_layer from .. import DeephavenFigure @@ -109,10 +109,10 @@ class PartitionManager: Attributes: by_vars: set[str]: The set of by_vars that can be used in a plot by - list_var: str: "x" or "y" depending on which var is a list - cols: str | list: The columns set by the list_var - pivot_vars: dict[str, str]: A dictionary that stores the "real" column - names if there is a list_var. 
This is needed in case the column names + list_param: str: "x" or "y" depending on which param is a list + cols: str | list: The columns set by the list_param + stacked_column_names: dict[str, str]: A dictionary that stores the "real" column + names if there is a list_param. This is needed in case the column names used are already in the table. has_color: bool: True if this figure has user set color, False otherwise facet_row: str: The facet row @@ -146,9 +146,9 @@ def __init__( ): self.by = None self.by_vars = None - self.list_var = None + self.list_param = None self.cols = None - self.pivot_vars = {} + self.stacked_column_names = {} self.has_color = None self.facet_row = None self.facet_col = None @@ -199,13 +199,13 @@ def set_long_mode_variables(self) -> None: self.groups.discard("supports_lists") return - self.list_var = var + self.list_param = var self.cols = cols - args["current_var"] = self.list_var + args["current_var"] = self.list_param - self.pivot_vars = get_unique_names(table, ["variable", "value"]) - self.args["pivot_vars"] = self.pivot_vars + self.stacked_column_names = get_unique_names(table, ["variable", "value"]) + self.args["stacked_column_names"] = self.stacked_column_names def convert_table_to_long_mode( self, @@ -227,7 +227,7 @@ def convert_table_to_long_mode( # if there is no plot by arg, the variable column becomes it if not self.args.get("by", None): - args["by"] = self.pivot_vars["variable"] + args["by"] = self.stacked_column_names["variable"] args["table"] = self.to_long_mode(table, self.cols) @@ -418,7 +418,11 @@ def process_partitions(self) -> Table | PartitionedTable: # preprocessor needs to be initialized after the always attached arguments are found self.preprocessor = Preprocessor( - args, self.groups, self.always_attached, self.pivot_vars, self.list_var + args, + self.groups, + self.always_attached, + self.stacked_column_names, + self.list_param, ) if partition_cols: @@ -468,12 +472,14 @@ def build_ternary_chain(self, cols: 
list[str]) -> str: Returns: The ternary string that builds the new column """ - ternary_string = f"{self.pivot_vars['value']} = " + ternary_string = f"{self.stacked_column_names['value']} = " for i, col in enumerate(cols): if i == len(cols) - 1: ternary_string += f"{col}" else: - ternary_string += f"{self.pivot_vars['variable']} == `{col}` ? {col} : " + ternary_string += ( + f"{self.stacked_column_names['variable']} == `{col}` ? {col} : " + ) return ternary_string def to_long_mode(self, table: Table, cols: list[str] | None) -> Table: @@ -494,7 +500,7 @@ def to_long_mode(self, table: Table, cols: list[str] | None) -> Table: new_tables = [] for col in cols: new_tables.append( - table.update_view(f"{self.pivot_vars['variable']} = `{col}`") + table.update_view(f"{self.stacked_column_names['variable']} = `{col}`") ) merged = merge(new_tables) @@ -545,7 +551,9 @@ def table_partition_generator( Yields: The tuple of table and current partition """ - column = self.pivot_vars["value"] if self.pivot_vars else None + column = ( + self.stacked_column_names["value"] if self.stacked_column_names else None + ) if self.preprocessor: tables = self.preprocessor.preprocess_partitioned_tables( self.constituents, column @@ -571,9 +579,13 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]: # if a tuple is returned here, it was preprocessed already so pivots aren't needed table, arg_update = table args.update(arg_update) - elif self.pivot_vars and self.pivot_vars["value"] and self.list_var: + elif ( + self.stacked_column_names + and self.stacked_column_names["value"] + and self.list_param + ): # there is a list of variables, so replace them with the combined column - args[self.list_var] = self.pivot_vars["value"] + args[self.list_param] = self.stacked_column_names["value"] args["current_partition"] = current_partition @@ -680,8 +692,8 @@ def create_figure(self) -> DeephavenFigure: # by color (colors might be used multiple times) self.marg_args["table"] = 
self.partitioned_table - if self.pivot_vars and self.pivot_vars["value"]: - self.marg_args[self.list_var] = self.pivot_vars["value"] + if self.stacked_column_names and self.stacked_column_names["value"]: + self.marg_args[self.list_param] = self.stacked_column_names["value"] self.marg_args["color"] = self.marg_color diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py b/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py index cd287c0df..41e9fe920 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/_private_utils.py @@ -233,8 +233,12 @@ def create_deephaven_figure( # this is a marginal, so provide an empty update function update_wrapper = lambda x: x - list_var = partitioned.list_var - pivot_col = partitioned.pivot_vars["value"] if partitioned.pivot_vars else None + list_param = partitioned.list_param + pivot_col = ( + partitioned.stacked_column_names["value"] + if partitioned.stacked_column_names + else None + ) by = partitioned.by update = {} @@ -243,9 +247,9 @@ def create_deephaven_figure( # by needs to be updated as if there is a list variable but by is None, the pivot column is used as the by update["by"] = by - if list_var: + if list_param: # if there is a list variable, update the list variable to the pivot column - update[list_var] = pivot_col + update[list_param] = pivot_col return ( update_wrapper(partitioned.create_figure()), diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py index 5a369bd3d..8320d29c2 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py @@ -54,10 +54,10 @@ class HistPreprocessor(UnivariateAwarePreprocessor): def __init__( self, args: 
dict[str, Any], - pivot_vars: dict[str, str], - list_var: str | None = None, + stacked_column_names: dict[str, str], + list_param: str | None = None, ): - super().__init__(args, pivot_vars, list_var) + super().__init__(args, stacked_column_names, list_param) self.range_table = None self.names = {} self.nbins = args.pop("nbins", 10) @@ -296,5 +296,6 @@ def preprocess_partitioned_tables( yield bin_counts.view([f"{bin_col} = {bin_mid}", new_agg_col]), { self.agg_var: new_agg_col, self.bin_var: bin_col, - f"hist_agg_label_{self.orientation}": hist_agg_label, + f"hist_agg_label": hist_agg_label, + f"hist_orientation": self.orientation, } diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py index 596c71de3..398de015a 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py @@ -17,8 +17,8 @@ class Preprocessor: Preprocessor for tables Attributes: - pivot_vars: dict[str, str]: A dictionary that stores the "real" column - names if there is a list_var. This is needed in case the column names + stacked_column_names: dict[str, str]: A dictionary that stores the "real" column + names if there is a list_param. This is needed in case the column names used are already in the table. 
always_attached: dict[tuple[str, str], tuple[dict[str, str], list[str], str]: The dict mapping the arg and column @@ -33,15 +33,15 @@ def __init__( args: dict[str, Any], groups: set[str], always_attached: dict[tuple[str, str], tuple[dict[str, str], list[str], str]], - pivot_vars: dict[str, str], - list_var: str | None, + stacked_column_names: dict[str, str], + list_param: str | None, ): self.args = args self.groups = groups self.preprocesser = None self.always_attached = always_attached - self.pivot_vars = pivot_vars - self.list_var = list_var + self.stacked_column_names = stacked_column_names + self.list_param = list_param self.prepare_preprocess() def prepare_preprocess(self) -> None: @@ -50,7 +50,7 @@ def prepare_preprocess(self) -> None: """ if "preprocess_hist" in self.groups: self.preprocesser = HistPreprocessor( - self.args, self.pivot_vars, self.list_var + self.args, self.stacked_column_names, self.list_param ) elif "preprocess_freq" in self.groups: self.preprocesser = FreqPreprocessor(self.args) diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariateAwarePreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariateAwarePreprocessor.py index b551859b4..c7760936f 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariateAwarePreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/UnivariateAwarePreprocessor.py @@ -1,9 +1,12 @@ from __future__ import annotations -from typing import Any +import abc +from typing import Any, Generator +from deephaven.table import Table -class UnivariateAwarePreprocessor: + +class UnivariateAwarePreprocessor(abc.ABC): """ A preprocessor that stores useful args for plots where possibly one of x or y or both can be specified, which impacts the orientation of the plot in ways that affect the preprocessing. 
@@ -11,8 +14,10 @@ class UnivariateAwarePreprocessor: Args: args: Figure creation args - pivot_vars: The vars with new column names if a list was passed in - list_var: The var that was passed in as a list + stacked_column_names: A dictionary that stores the "real" column + names if there is a list_param. This is needed in case the column names + used are already in the table. + list_param: The param that was passed in as a list Attributes: args: dict[str, str]: Figure creation args @@ -31,8 +36,8 @@ class UnivariateAwarePreprocessor: def __init__( self, args: dict[str, Any], - pivot_vars: dict[str, str] | None = None, - list_var: str | None = None, + stacked_column_names: dict[str, str] | None = None, + list_param: str | None = None, ): self.args = args self.table = args["table"] @@ -41,19 +46,19 @@ def __init__( self.bin_var = "x" if self.orientation == "v" else "y" self.agg_var = "y" if self.bin_var == "x" else "x" self.bin_col: str = ( - pivot_vars["value"] - if pivot_vars and list_var and list_var == self.bin_var + stacked_column_names["value"] + if stacked_column_names and list_param and list_param == self.bin_var else args[self.bin_var] ) if self.args.get(self.agg_var): self.agg_col: str = ( - pivot_vars["value"] - if pivot_vars and list_var and list_var == self.agg_var + stacked_column_names["value"] + if stacked_column_names and list_param and list_param == self.agg_var else args[self.agg_var] ) else: - # if bar_var is not set, the value column is the same as the axis column + # if agg_var is not set, the value column is the same as the axis column # because both the axis bins and value are computed from the same inputs self.agg_col = self.bin_col @@ -71,9 +76,25 @@ def calculate_bar_orientation(self): # Note that this will also be the default if both are specified # plotly express does some more sophisticated checking for data types # when both are specified but categorical data will fail due to the - # engine preprocessing in our implementation so just 
assume vertical + # engine preprocessing in our implementation so just assume vertical return "v" elif y: return "h" raise ValueError("Could not determine orientation") + + @abc.abstractmethod + def preprocess_partitioned_tables( + self, tables: list[Table], column: str | None = None + ) -> Generator[tuple[Table, dict[str, str | None]], None, None]: + """ + Preprocess the tables into the appropriate format for the plot. + + Args: + tables: A list of tables to preprocess + column: The column to aggregate on + + Yields: + A tuple containing (the new table, an update to make to the args) + """ + raise NotImplementedError