Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Make dx histogram behavior consistent with px #1002

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
22 changes: 22 additions & 0 deletions plugins/plotly-express/docs/histogram.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,28 @@ hist_3_bins = dx.histogram(setosa, x="SepalLength", nbins=3)
hist_8_bins = dx.histogram(setosa, x="SepalLength", nbins=8)
```

### Bin and aggregate on different columns

If the plot orientation is vertical (`"v"`), the `x` column is binned and the `y` column is aggregated. The operations are flipped if the plot orientation is horizontal.


```python order=hist_v,hist_h,hist_avg,iris
import deephaven.plot.express as dx
iris = dx.data.iris()

# subset to get specific species
setosa = iris.where("Species == `setosa`")

# The default orientation is "v" (vertical) and the default aggregation function is "sum"
hist_v = dx.histogram(setosa, x="SepalLength", y="SepalWidth")

# Control the plot orientation using orientation
hist_h = dx.histogram(setosa, x="SepalLength", y="SepalWidth", orientation="h")

# Control the aggregation function using histfunc
hist_avg = dx.histogram(setosa, x="SepalLength", y="SepalWidth", histfunc="avg")
```

### Distributions of several groups

Histograms can also be used to compare the distributional properties of different groups of data, though they may be a little harder to read than [box plots](box.md) or [violin plots](violin.md). Pass the name of the grouping column(s) to the `by` argument.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@
"current_col",
"current_var",
"labels",
"hist_val_name",
"pivot_vars",
"hist_agg_label",
"hist_orientation",
"stacked_column_names",
"current_partition",
"colors",
"unsafe_update_figure",
Expand Down Expand Up @@ -684,7 +685,7 @@ def handle_custom_args(
return trace_generator


def get_list_var_info(data_cols: Mapping[str, str | list[str]]) -> set[str]:
def get_list_param_info(data_cols: Mapping[str, str | list[str]]) -> set[str]:
"""Extract the variable that is a list.

Args:
Expand Down Expand Up @@ -824,7 +825,8 @@ def hover_text_generator(

def compute_labels(
hover_mapping: list[dict[str, str]],
hist_val_name: str | None,
hist_agg_label: str | None,
hist_orientation: str | None,
heatmap_agg_label: str | None,
# hover_data - todo, dependent on arrays supported in data mappings
types: set[str],
Expand All @@ -837,7 +839,8 @@ def compute_labels(

Args:
hover_mapping: The mapping of variables to columns
hist_val_name: The histogram name for the value axis, generally histfunc
hist_agg_label: The histogram agg label
hist_orientation: The histogram orientation
heatmap_agg_label: The aggregate density heatmap column title
types: Any types of this chart that require special processing
labels: A dictionary of old column name to new column name mappings
Expand All @@ -847,7 +850,7 @@ def compute_labels(
the renamed current_col
"""

calculate_hist_labels(hist_val_name, hover_mapping[0])
calculate_hist_labels(hist_agg_label, hist_orientation, hover_mapping[0])

calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0], labels)

Expand Down Expand Up @@ -880,27 +883,30 @@ def calculate_density_heatmap_labels(


def calculate_hist_labels(
hist_val_name: str | None, current_mapping: dict[str, str]
hist_agg_label: str | None,
hist_orientation: str | None,
hover_mapping: dict[str, str],
) -> None:
"""Calculate the histogram labels

Args:
hist_val_name: The histogram name for the value axis, generally histfunc
current_mapping: The mapping of variables to columns
hist_agg_label: The histogram agg label
hist_orientation: The histogram orientation
hover_mapping: The mapping of variables to columns

"""
if hist_val_name:
# swap the names
current_mapping["x"], current_mapping["y"] = (
current_mapping["y"],
current_mapping["x"],
)
# only one should be set
if hist_orientation == "h" and hist_agg_label:
# a bar chart oriented horizontally has the histfunc on the x-axis
hover_mapping["x"] = hist_agg_label
elif hist_orientation == "v" and hist_agg_label:
hover_mapping["y"] = hist_agg_label


def add_axis_titles(
custom_call_args: dict[str, Any],
hover_mapping: list[dict[str, str]],
hist_val_name: str | None,
hist_agg_label: str | None,
heatmap_agg_label: str | None,
) -> None:
"""Add axis titles. Generally, this only applies when there is a list variable
Expand All @@ -909,7 +915,7 @@ def add_axis_titles(
custom_call_args: The custom_call_args that are used to
create hover and axis titles
hover_mapping: The mapping of variables to columns
hist_val_name: The histogram name for the value axis, generally histfunc
hist_agg_label: The histogram agg label
heatmap_agg_label: The aggregate density heatmap column title

"""
Expand All @@ -919,8 +925,8 @@ def add_axis_titles(
new_xaxis_titles = None
new_yaxis_titles = None

if hist_val_name:
# hist names are already set up in the mapping
if hist_agg_label:
# hist labels are already set up in the mapping
new_xaxis_titles = [hover_mapping[0].get("x", None)]
new_yaxis_titles = [hover_mapping[0].get("y", None)]

Expand All @@ -945,7 +951,7 @@ def create_hover_and_axis_titles(
hover_mapping: list[dict[str, str]],
) -> Generator[dict[str, Any], None, None]:
"""Create hover text and axis titles. There are three main behaviors.
First is "current_col", "current_var", and "pivot_vars" are specified in
First is "current_col", "current_var", and "stacked_column_names" are specified in
"custom_call_args".
In this case, there is a list of variables, but they are layered outside
the generate function.
Expand Down Expand Up @@ -975,17 +981,19 @@ def create_hover_and_axis_titles(
Yields:
Dicts containing hover updates
"""
types = get_list_var_info(data_cols)
types = get_list_param_info(data_cols)

labels = custom_call_args.get("labels", None)
hist_val_name = custom_call_args.get("hist_val_name", None)
hist_agg_label = custom_call_args.get("hist_agg_label", None)
hist_orientation = custom_call_args.get("hist_orientation", None)
heatmap_agg_label = custom_call_args.get("heatmap_agg_label", None)

current_partition = custom_call_args.get("current_partition", {})

compute_labels(
hover_mapping,
hist_val_name,
hist_agg_label,
hist_orientation,
heatmap_agg_label,
types,
labels,
Expand All @@ -998,7 +1006,12 @@ def create_hover_and_axis_titles(
# it's possible that heatmap_agg_label was relabeled, so grab the new label
heatmap_agg_label = hover_mapping[0]["z"]

add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_agg_label)
add_axis_titles(
custom_call_args,
hover_mapping,
hist_agg_label,
heatmap_agg_label,
)

return hover_text

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from deephaven.table import Table, PartitionedTable
from deephaven import pandas as dhpd
from deephaven import merge, empty_table
from deephaven import merge

from ._layer import atomic_layer
from .. import DeephavenFigure
Expand Down Expand Up @@ -109,10 +109,10 @@ class PartitionManager:

Attributes:
by_vars: set[str]: The set of by_vars that can be used in a plot by
list_var: str: "x" or "y" depending on which var is a list
cols: str | list: The columns set by the list_var
pivot_vars: dict[str, str]: A dictionary that stores the "real" column
names if there is a list_var. This is needed in case the column names
list_param: str: "x" or "y" depending on which param is a list
cols: str | list: The columns set by the list_param
stacked_column_names: dict[str, str]: A dictionary that stores the "real" column
names if there is a list_param. This is needed in case the column names
used are already in the table.
has_color: bool: True if this figure has user set color, False otherwise
facet_row: str: The facet row
Expand Down Expand Up @@ -146,9 +146,9 @@ def __init__(
):
self.by = None
self.by_vars = None
self.list_var = None
self.list_param = None
self.cols = None
self.pivot_vars = {}
self.stacked_column_names = {}
self.has_color = None
self.facet_row = None
self.facet_col = None
Expand Down Expand Up @@ -199,13 +199,13 @@ def set_long_mode_variables(self) -> None:
self.groups.discard("supports_lists")
return

self.list_var = var
self.list_param = var
self.cols = cols

args["current_var"] = self.list_var
args["current_var"] = self.list_param

self.pivot_vars = get_unique_names(table, ["variable", "value"])
self.args["pivot_vars"] = self.pivot_vars
self.stacked_column_names = get_unique_names(table, ["variable", "value"])
self.args["stacked_column_names"] = self.stacked_column_names

def convert_table_to_long_mode(
self,
Expand All @@ -227,7 +227,7 @@ def convert_table_to_long_mode(

# if there is no plot by arg, the variable column becomes it
if not self.args.get("by", None):
args["by"] = self.pivot_vars["variable"]
args["by"] = self.stacked_column_names["variable"]

args["table"] = self.to_long_mode(table, self.cols)

Expand Down Expand Up @@ -418,7 +418,11 @@ def process_partitions(self) -> Table | PartitionedTable:

# preprocessor needs to be initialized after the always attached arguments are found
self.preprocessor = Preprocessor(
args, self.groups, self.always_attached, self.pivot_vars
args,
self.groups,
self.always_attached,
self.stacked_column_names,
self.list_param,
)

if partition_cols:
Expand Down Expand Up @@ -468,12 +472,14 @@ def build_ternary_chain(self, cols: list[str]) -> str:
Returns:
The ternary string that builds the new column
"""
ternary_string = f"{self.pivot_vars['value']} = "
ternary_string = f"{self.stacked_column_names['value']} = "
for i, col in enumerate(cols):
if i == len(cols) - 1:
ternary_string += f"{col}"
else:
ternary_string += f"{self.pivot_vars['variable']} == `{col}` ? {col} : "
ternary_string += (
f"{self.stacked_column_names['variable']} == `{col}` ? {col} : "
)
return ternary_string

def to_long_mode(self, table: Table, cols: list[str] | None) -> Table:
Expand All @@ -494,7 +500,7 @@ def to_long_mode(self, table: Table, cols: list[str] | None) -> Table:
new_tables = []
for col in cols:
new_tables.append(
table.update_view(f"{self.pivot_vars['variable']} = `{col}`")
table.update_view(f"{self.stacked_column_names['variable']} = `{col}`")
)

merged = merge(new_tables)
Expand Down Expand Up @@ -545,7 +551,9 @@ def table_partition_generator(
Yields:
The tuple of table and current partition
"""
column = self.pivot_vars["value"] if self.pivot_vars else None
column = (
self.stacked_column_names["value"] if self.stacked_column_names else None
)
if self.preprocessor:
tables = self.preprocessor.preprocess_partitioned_tables(
self.constituents, column
Expand All @@ -571,9 +579,13 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]:
# if a tuple is returned here, it was preprocessed already so pivots aren't needed
table, arg_update = table
args.update(arg_update)
elif self.pivot_vars and self.pivot_vars["value"] and self.list_var:
elif (
self.stacked_column_names
and self.stacked_column_names["value"]
and self.list_param
):
# there is a list of variables, so replace them with the combined column
args[self.list_var] = self.pivot_vars["value"]
args[self.list_param] = self.stacked_column_names["value"]

args["current_partition"] = current_partition

Expand Down Expand Up @@ -680,8 +692,8 @@ def create_figure(self) -> DeephavenFigure:
# by color (colors might be used multiple times)
self.marg_args["table"] = self.partitioned_table

if self.pivot_vars and self.pivot_vars["value"]:
self.marg_args[self.list_var] = self.pivot_vars["value"]
if self.stacked_column_names and self.stacked_column_names["value"]:
self.marg_args[self.list_param] = self.stacked_column_names["value"]

self.marg_args["color"] = self.marg_color

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,12 @@ def create_deephaven_figure(
# this is a marginal, so provide an empty update function
update_wrapper = lambda x: x

list_var = partitioned.list_var
pivot_col = partitioned.pivot_vars["value"] if partitioned.pivot_vars else None
list_param = partitioned.list_param
pivot_col = (
partitioned.stacked_column_names["value"]
if partitioned.stacked_column_names
else None
)
by = partitioned.by

update = {}
Expand All @@ -243,9 +247,9 @@ def create_deephaven_figure(
# by needs to be updated as if there is a list variable but by is None, the pivot column is used as the by
update["by"] = by

if list_var:
if list_param:
# if there is a list variable, update the list variable to the pivot column
update[list_var] = pivot_col
update[list_param] = pivot_col

return (
update_wrapper(partitioned.create_figure()),
Expand Down Expand Up @@ -489,7 +493,6 @@ def shared_histogram(is_marginal: bool = True, **args: Any) -> DeephavenFigure:
set_all(args, HISTOGRAM_DEFAULTS)

args["bargap"] = 0
args["hist_val_name"] = args.get("histfunc", "count")

func = px.bar
groups = {"bar", "preprocess_hist", "supports_lists"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._private_utils import validate_common_args, process_args
from ..shared import default_callback
from ..deephaven_figure import generate_figure, DeephavenFigure
from ..types import PartitionableTableLike
from ..types import PartitionableTableLike, Orientation


def bar(
Expand Down Expand Up @@ -42,6 +42,7 @@ def bar(
range_color: list[float] | None = None,
color_continuous_midpoint: float | None = None,
opacity: float | None = None,
orientation: Orientation | None = None,
barmode: str = "relative",
log_x: bool = False,
log_y: bool = False,
Expand Down Expand Up @@ -114,6 +115,12 @@ def bar(
color_continuous_midpoint: A number that is the midpoint of the color axis
opacity: Opacity to apply to all markers. 0 is completely transparent
and 1 is completely opaque.
orientation: The orientation of the bars.
If 'v', the bars are vertical.
If 'h', the bars are horizontal.
Defaults to 'v' if only `x` is specified.
Defaults to 'h' if only `y` is specified.
Defaults to 'v' if both `x` and `y` are specified unless `x` is passed only numeric columns and `y` is not.
barmode: If 'relative', bars are stacked. If 'overlay', bars are drawn on top
of each other. If 'group', bars are drawn next to each other.
log_x: A boolean or list of booleans that specify if
Expand Down
Loading
Loading