Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jnumainville committed Dec 18, 2024
1 parent ba4ccc0 commit b8a512a
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@
"current_col",
"current_var",
"labels",
"hist_agg_label_h",
"hist_agg_label_v",
"pivot_vars",
"hist_agg_label",
"hist_orientation",
"stacked_column_names",
"current_partition",
"colors",
"unsafe_update_figure",
Expand Down Expand Up @@ -685,7 +685,7 @@ def handle_custom_args(
return trace_generator


def get_list_var_info(data_cols: Mapping[str, str | list[str]]) -> set[str]:
def get_list_param_info(data_cols: Mapping[str, str | list[str]]) -> set[str]:
"""Extract the variable that is a list.
Args:
Expand Down Expand Up @@ -825,8 +825,8 @@ def hover_text_generator(

def compute_labels(
hover_mapping: list[dict[str, str]],
hist_agg_label_h: str | None,
hist_agg_label_v: str | None,
hist_agg_label: str | None,
hist_orientation: str | None,
heatmap_agg_label: str | None,
# hover_data - todo, dependent on arrays supported in data mappings
types: set[str],
Expand All @@ -839,8 +839,8 @@ def compute_labels(
Args:
hover_mapping: The mapping of variables to columns
hist_agg_label_h: The histogram agg label when oriented horizontally
hist_agg_label_v: The histogram agg label when oriented vertically
hist_agg_label: The histogram agg label
hist_orientation: The histogram orientation
heatmap_agg_label: The aggregate density heatmap column title
types: Any types of this chart that require special processing
labels: A dictionary of old column name to new column name mappings
Expand All @@ -850,7 +850,7 @@ def compute_labels(
the renamed current_col
"""

calculate_hist_labels(hist_agg_label_h, hist_agg_label_v, hover_mapping[0])
calculate_hist_labels(hist_agg_label, hist_orientation, hover_mapping[0])

calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0], labels)

Expand Down Expand Up @@ -883,31 +883,30 @@ def calculate_density_heatmap_labels(


def calculate_hist_labels(
hist_agg_label_h: str | None,
hist_agg_label_v: str | None,
hist_agg_label: str | None,
hist_orientation: str | None,
hover_mapping: dict[str, str],
) -> None:
"""Calculate the histogram labels
Args:
hist_agg_label_h: The histogram agg label when oriented horizontally
hist_agg_label_v: The histogram agg label when oriented vertically
hist_agg_label: The histogram agg label
hist_orientation: The histogram orientation
hover_mapping: The mapping of variables to columns
"""
# only one should be set
if hist_agg_label_h:
if hist_orientation == "h" and hist_agg_label:
# a bar chart oriented horizontally has the histfunc on the x-axis
hover_mapping["x"] = hist_agg_label_h
elif hist_agg_label_v:
hover_mapping["y"] = hist_agg_label_v
hover_mapping["x"] = hist_agg_label
elif hist_orientation == "v" and hist_agg_label:
hover_mapping["y"] = hist_agg_label


def add_axis_titles(
custom_call_args: dict[str, Any],
hover_mapping: list[dict[str, str]],
hist_agg_label_h: str | None,
hist_agg_label_v: str | None,
hist_agg_label: str | None,
heatmap_agg_label: str | None,
) -> None:
"""Add axis titles. Generally, this only applies when there is a list variable
Expand All @@ -916,8 +915,7 @@ def add_axis_titles(
custom_call_args: The custom_call_args that are used to
create hover and axis titles
hover_mapping: The mapping of variables to columns
hist_agg_label_h: The histogram agg label when oriented horizontally
hist_agg_label_v: The histogram agg label when oriented vertically
hist_agg_label: The histogram agg label
heatmap_agg_label: The aggregate density heatmap column title
"""
Expand All @@ -927,7 +925,7 @@ def add_axis_titles(
new_xaxis_titles = None
new_yaxis_titles = None

if hist_agg_label_h or hist_agg_label_v:
if hist_agg_label:
# hist labels are already set up in the mapping
new_xaxis_titles = [hover_mapping[0].get("x", None)]
new_yaxis_titles = [hover_mapping[0].get("y", None)]
Expand All @@ -953,7 +951,7 @@ def create_hover_and_axis_titles(
hover_mapping: list[dict[str, str]],
) -> Generator[dict[str, Any], None, None]:
"""Create hover text and axis titles. There are three main behaviors.
First is "current_col", "current_var", and "pivot_vars" are specified in
First is "current_col", "current_var", and "stacked_column_names" are specified in
"custom_call_args".
In this case, there is a list of variables, but they are layered outside
the generate function.
Expand Down Expand Up @@ -983,19 +981,19 @@ def create_hover_and_axis_titles(
Yields:
Dicts containing hover updates
"""
types = get_list_var_info(data_cols)
types = get_list_param_info(data_cols)

labels = custom_call_args.get("labels", None)
hist_agg_label_h = custom_call_args.get("hist_agg_label_h", None)
hist_agg_label_v = custom_call_args.get("hist_agg_label_v", None)
hist_agg_label = custom_call_args.get("hist_agg_label", None)
hist_orientation = custom_call_args.get("hist_orientation", None)
heatmap_agg_label = custom_call_args.get("heatmap_agg_label", None)

current_partition = custom_call_args.get("current_partition", {})

compute_labels(
hover_mapping,
hist_agg_label_h,
hist_agg_label_v,
hist_agg_label,
hist_orientation,
heatmap_agg_label,
types,
labels,
Expand All @@ -1011,8 +1009,7 @@ def create_hover_and_axis_titles(
add_axis_titles(
custom_call_args,
hover_mapping,
hist_agg_label_h,
hist_agg_label_v,
hist_agg_label,
heatmap_agg_label,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from deephaven.table import Table, PartitionedTable
from deephaven import pandas as dhpd
from deephaven import merge, empty_table
from deephaven import merge

from ._layer import atomic_layer
from .. import DeephavenFigure
Expand Down Expand Up @@ -109,10 +109,10 @@ class PartitionManager:
Attributes:
by_vars: set[str]: The set of by_vars that can be used in a plot by
list_var: str: "x" or "y" depending on which var is a list
cols: str | list: The columns set by the list_var
pivot_vars: dict[str, str]: A dictionary that stores the "real" column
names if there is a list_var. This is needed in case the column names
list_param: str: "x" or "y" depending on which param is a list
cols: str | list: The columns set by the list_param
stacked_column_names: dict[str, str]: A dictionary that stores the "real" column
names if there is a list_param. This is needed in case the column names
used are already in the table.
has_color: bool: True if this figure has user set color, False otherwise
facet_row: str: The facet row
Expand Down Expand Up @@ -146,9 +146,9 @@ def __init__(
):
self.by = None
self.by_vars = None
self.list_var = None
self.list_param = None
self.cols = None
self.pivot_vars = {}
self.stacked_column_names = {}
self.has_color = None
self.facet_row = None
self.facet_col = None
Expand Down Expand Up @@ -199,13 +199,13 @@ def set_long_mode_variables(self) -> None:
self.groups.discard("supports_lists")
return

self.list_var = var
self.list_param = var
self.cols = cols

args["current_var"] = self.list_var
args["current_var"] = self.list_param

self.pivot_vars = get_unique_names(table, ["variable", "value"])
self.args["pivot_vars"] = self.pivot_vars
self.stacked_column_names = get_unique_names(table, ["variable", "value"])
self.args["stacked_column_names"] = self.stacked_column_names

def convert_table_to_long_mode(
self,
Expand All @@ -227,7 +227,7 @@ def convert_table_to_long_mode(

# if there is no plot by arg, the variable column becomes it
if not self.args.get("by", None):
args["by"] = self.pivot_vars["variable"]
args["by"] = self.stacked_column_names["variable"]

args["table"] = self.to_long_mode(table, self.cols)

Expand Down Expand Up @@ -418,7 +418,11 @@ def process_partitions(self) -> Table | PartitionedTable:

# preprocessor needs to be initialized after the always attached arguments are found
self.preprocessor = Preprocessor(
args, self.groups, self.always_attached, self.pivot_vars, self.list_var
args,
self.groups,
self.always_attached,
self.stacked_column_names,
self.list_param,
)

if partition_cols:
Expand Down Expand Up @@ -468,12 +472,14 @@ def build_ternary_chain(self, cols: list[str]) -> str:
Returns:
The ternary string that builds the new column
"""
ternary_string = f"{self.pivot_vars['value']} = "
ternary_string = f"{self.stacked_column_names['value']} = "
for i, col in enumerate(cols):
if i == len(cols) - 1:
ternary_string += f"{col}"
else:
ternary_string += f"{self.pivot_vars['variable']} == `{col}` ? {col} : "
ternary_string += (
f"{self.stacked_column_names['variable']} == `{col}` ? {col} : "
)
return ternary_string

def to_long_mode(self, table: Table, cols: list[str] | None) -> Table:
Expand All @@ -494,7 +500,7 @@ def to_long_mode(self, table: Table, cols: list[str] | None) -> Table:
new_tables = []
for col in cols:
new_tables.append(
table.update_view(f"{self.pivot_vars['variable']} = `{col}`")
table.update_view(f"{self.stacked_column_names['variable']} = `{col}`")
)

merged = merge(new_tables)
Expand Down Expand Up @@ -545,7 +551,9 @@ def table_partition_generator(
Yields:
The tuple of table and current partition
"""
column = self.pivot_vars["value"] if self.pivot_vars else None
column = (
self.stacked_column_names["value"] if self.stacked_column_names else None
)
if self.preprocessor:
tables = self.preprocessor.preprocess_partitioned_tables(
self.constituents, column
Expand All @@ -571,9 +579,13 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]:
# if a tuple is returned here, it was preprocessed already so pivots aren't needed
table, arg_update = table
args.update(arg_update)
elif self.pivot_vars and self.pivot_vars["value"] and self.list_var:
elif (
self.stacked_column_names
and self.stacked_column_names["value"]
and self.list_param
):
# there is a list of variables, so replace them with the combined column
args[self.list_var] = self.pivot_vars["value"]
args[self.list_param] = self.stacked_column_names["value"]

args["current_partition"] = current_partition

Expand Down Expand Up @@ -680,8 +692,8 @@ def create_figure(self) -> DeephavenFigure:
# by color (colors might be used multiple times)
self.marg_args["table"] = self.partitioned_table

if self.pivot_vars and self.pivot_vars["value"]:
self.marg_args[self.list_var] = self.pivot_vars["value"]
if self.stacked_column_names and self.stacked_column_names["value"]:
self.marg_args[self.list_param] = self.stacked_column_names["value"]

self.marg_args["color"] = self.marg_color

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,12 @@ def create_deephaven_figure(
# this is a marginal, so provide an empty update function
update_wrapper = lambda x: x

list_var = partitioned.list_var
pivot_col = partitioned.pivot_vars["value"] if partitioned.pivot_vars else None
list_param = partitioned.list_param
pivot_col = (
partitioned.stacked_column_names["value"]
if partitioned.stacked_column_names
else None
)
by = partitioned.by

update = {}
Expand All @@ -243,9 +247,9 @@ def create_deephaven_figure(
# by needs to be updated as if there is a list variable but by is None, the pivot column is used as the by
update["by"] = by

if list_var:
if list_param:
# if there is a list variable, update the list variable to the pivot column
update[list_var] = pivot_col
update[list_param] = pivot_col

return (
update_wrapper(partitioned.create_figure()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class HistPreprocessor(UnivariateAwarePreprocessor):
def __init__(
self,
args: dict[str, Any],
pivot_vars: dict[str, str],
list_var: str | None = None,
stacked_column_names: dict[str, str],
list_param: str | None = None,
):
super().__init__(args, pivot_vars, list_var)
super().__init__(args, stacked_column_names, list_param)
self.range_table = None
self.names = {}
self.nbins = args.pop("nbins", 10)
Expand Down Expand Up @@ -296,5 +296,6 @@ def preprocess_partitioned_tables(
yield bin_counts.view([f"{bin_col} = {bin_mid}", new_agg_col]), {
self.agg_var: new_agg_col,
self.bin_var: bin_col,
f"hist_agg_label_{self.orientation}": hist_agg_label,
f"hist_agg_label": hist_agg_label,
f"hist_orientation": self.orientation,
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class Preprocessor:
Preprocessor for tables
Attributes:
pivot_vars: dict[str, str]: A dictionary that stores the "real" column
names if there is a list_var. This is needed in case the column names
stacked_column_names: dict[str, str]: A dictionary that stores the "real" column
names if there is a list_param. This is needed in case the column names
used are already in the table.
always_attached: dict[tuple[str, str],
tuple[dict[str, str], list[str], str]: The dict mapping the arg and column
Expand All @@ -33,15 +33,15 @@ def __init__(
args: dict[str, Any],
groups: set[str],
always_attached: dict[tuple[str, str], tuple[dict[str, str], list[str], str]],
pivot_vars: dict[str, str],
list_var: str | None,
stacked_column_names: dict[str, str],
list_param: str | None,
):
self.args = args
self.groups = groups
self.preprocesser = None
self.always_attached = always_attached
self.pivot_vars = pivot_vars
self.list_var = list_var
self.stacked_column_names = stacked_column_names
self.list_param = list_param
self.prepare_preprocess()

def prepare_preprocess(self) -> None:
Expand All @@ -50,7 +50,7 @@ def prepare_preprocess(self) -> None:
"""
if "preprocess_hist" in self.groups:
self.preprocesser = HistPreprocessor(
self.args, self.pivot_vars, self.list_var
self.args, self.stacked_column_names, self.list_param
)
elif "preprocess_freq" in self.groups:
self.preprocesser = FreqPreprocessor(self.args)
Expand Down
Loading

0 comments on commit b8a512a

Please sign in to comment.