Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include Time Spine Nodes in Dataflow Plan #1548

Open
wants to merge 8 commits into
base: court/simp7
Choose a base branch
from
  •  
  •  
  •  
127 changes: 110 additions & 17 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import logging
import time
from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, Union
Expand Down Expand Up @@ -87,9 +88,12 @@
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
Expand Down Expand Up @@ -646,14 +650,19 @@ def _build_derived_metric_output_node(
metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
)
if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
# TODO: move this to a helper method
time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
output_node = JoinToTimeSpineNode.create(
parent_node=output_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
offset_window=metric_spec.offset_window,
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
)

# TODO: fix bug here where filter specs are being included when aggregating.
if len(metric_spec.filter_spec_set.all_filter_specs) > 0 or predicate_pushdown_state.time_range_constraint:
# FilterElementsNode will only be needed if there are where filter specs that were selected in the group by.
specs_in_filters = set(
Expand Down Expand Up @@ -1037,8 +1046,7 @@ def _find_source_node_recipe_non_cached(
)
# If metric_time is requested without metrics, choose appropriate time spine node to select those values from.
if linkable_specs_to_satisfy.metric_time_specs:
time_spine_source = self._choose_time_spine_source(linkable_specs_to_satisfy.metric_time_specs)
time_spine_node = self._source_node_set.time_spine_nodes[time_spine_source.base_granularity]
time_spine_node = self._choose_time_spine_metric_time_node(linkable_specs_to_satisfy.metric_time_specs)
candidate_nodes_for_right_side_of_join += [time_spine_node]
candidate_nodes_for_left_side_of_join += [time_spine_node]
default_join_type = SqlJoinType.FULL_OUTER
Expand Down Expand Up @@ -1077,7 +1085,7 @@ def _find_source_node_recipe_non_cached(
desired_linkable_specs=linkable_specs_to_satisfy_tuple,
nodes=candidate_nodes_for_right_side_of_join,
metric_time_dimension_reference=self._metric_time_dimension_reference,
time_spine_nodes=self._source_node_set.time_spine_nodes_tuple,
time_spine_metric_time_nodes=self._source_node_set.time_spine_metric_time_nodes_tuple,
)
logger.debug(
LazyFormat(
Expand Down Expand Up @@ -1124,7 +1132,7 @@ def _find_source_node_recipe_non_cached(
semantic_model_lookup=self._semantic_model_lookup,
nodes_available_for_joins=self._sort_by_suitability(candidate_nodes_for_right_side_of_join),
node_data_set_resolver=self._node_data_set_resolver,
time_spine_nodes=self._source_node_set.time_spine_nodes_tuple,
time_spine_metric_time_nodes=self._source_node_set.time_spine_metric_time_nodes_tuple,
)

# Dict from the node that contains the source node to the evaluation results.
Expand Down Expand Up @@ -1615,15 +1623,22 @@ def _build_aggregated_measure_from_measure_source_node(

# If querying an offset metric, join to time spine before aggregation.
if before_aggregation_time_spine_join_description and base_queried_agg_time_dimension_specs:
# TODO: move all of this to a helper function
assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
f"new use case."
)
# This also uses the original time range constraint due to the application of the time window intervals
# in join rendering

join_on_time_dimension_spec = self._determine_time_spine_join_spec(
measure_properties=measure_properties, required_time_spine_specs=base_queried_agg_time_dimension_specs
)
required_time_spine_specs = (join_on_time_dimension_spec,) + base_queried_agg_time_dimension_specs
time_spine_node = self._build_time_spine_node(required_time_spine_specs)
unaggregated_measure_node = JoinToTimeSpineNode.create(
parent_node=unaggregated_measure_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs,
join_on_time_dimension_spec=join_on_time_dimension_spec,
offset_window=before_aggregation_time_spine_join_description.offset_window,
offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
join_type=before_aggregation_time_spine_join_description.join_type,
Expand Down Expand Up @@ -1666,10 +1681,13 @@ def _build_aggregated_measure_from_measure_source_node(
measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
)
if after_aggregation_time_spine_join_description and queried_agg_time_dimension_specs:
# TODO: move all of this to a helper function
assert after_aggregation_time_spine_join_description.join_type is SqlJoinType.LEFT_OUTER, (
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
time_spine_required_specs = copy.deepcopy(queried_agg_time_dimension_specs)

# Find filters that contain only metric_time or agg_time_dimension. They will be applied to the time spine table.
agg_time_only_filters: List[WhereFilterSpec] = []
non_agg_time_filters: List[WhereFilterSpec] = []
Expand All @@ -1679,24 +1697,23 @@ def _build_aggregated_measure_from_measure_source_node(
)
if set(included_agg_time_specs) == set(filter_spec.linkable_spec_set.as_tuple):
agg_time_only_filters.append(filter_spec)
if filter_spec.linkable_spec_set.time_dimension_specs_with_custom_grain:
raise ValueError(
"Using custom granularity in filters for `join_to_timespine` metrics is not yet fully supported. "
"This feature is coming soon!"
)
for agg_time_spec in included_agg_time_specs:
if agg_time_spec not in time_spine_required_specs:
time_spine_required_specs.append(agg_time_spec)
else:
non_agg_time_filters.append(filter_spec)

# TODO: split this node into TimeSpineSourceNode and JoinToTimeSpineNode - then can use standard nodes here
# like JoinToCustomGranularityNode, WhereConstraintNode, etc.
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=queried_agg_time_dimension_specs,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
where_filter_specs=agg_time_only_filters,
)
output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
parent_node=aggregate_measures_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
join_type=after_aggregation_time_spine_join_description.join_type,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
offset_window=after_aggregation_time_spine_join_description.offset_window,
offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
time_spine_filters=agg_time_only_filters,
)

# Since new rows might have been added due to time spine join, re-apply constraints here. Only re-apply filters
Expand Down Expand Up @@ -1812,3 +1829,79 @@ def _choose_time_spine_source(self, required_time_spine_specs: Sequence[TimeDime
required_time_spine_specs=required_time_spine_specs,
time_spine_sources=self._source_node_builder.time_spine_sources,
)

def _choose_time_spine_metric_time_node(
    self, required_time_spine_specs: Sequence[TimeDimensionSpec]
) -> MetricTimeDimensionTransformNode:
    """Pick the metric_time time spine node that can satisfy the given specs.

    The time spine source is selected with `_choose_time_spine_source`, and its base granularity is used as the
    key into the source node set's `time_spine_metric_time_nodes` mapping.
    """
    chosen_source = self._choose_time_spine_source(required_time_spine_specs)
    metric_time_nodes_by_grain = self._source_node_set.time_spine_metric_time_nodes
    return metric_time_nodes_by_grain[chosen_source.base_granularity]

def _choose_time_spine_read_node(self, time_spine_source: TimeSpineSource) -> ReadSqlSourceNode:
    """Return the ReadSqlSourceNode registered for the given time spine source.

    The node is looked up in the source node set's `time_spine_read_nodes` mapping using the base granularity of
    `time_spine_source` as the key.
    """
    return self._source_node_set.time_spine_read_nodes[time_spine_source.base_granularity]

def _build_time_spine_node(
    self,
    queried_time_spine_specs: Sequence[TimeDimensionSpec],
    where_filter_specs: Sequence[WhereFilterSpec] = (),
    time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> DataflowPlanNode:
    """Build the dataflow node that provides time spine values for the queried specs.

    A time spine read node is wrapped in a `TransformTimeDimensionsNode` that requests every time dimension spec
    resolved from the queried specs and the filter specs. That node is then handed to
    `_build_pre_aggregation_plan` together with the time range constraint and the where filters, with
    `filter_to_specs` set to the queried specs.

    Args:
        queried_time_spine_specs: Time dimension specs the resulting node should output.
        where_filter_specs: Filters forwarded to the pre-aggregation plan; the specs they require are also
            requested from the time spine.
        time_range_constraint: Optional time range constraint forwarded to the pre-aggregation plan.

    Returns:
        The node produced by `_build_pre_aggregation_plan`.
    """
    spec_set_including_filters = self.__get_required_linkable_specs(
        queried_linkable_specs=LinkableSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
        filter_specs=where_filter_specs,
    )
    all_needed_specs = spec_set_including_filters.time_dimension_specs

    # TODO: support multiple time spines here. Build the node on the one with the smallest base grain, then pass
    # custom_granularity_specs into _build_pre_aggregation_plan if the smallest time spine can't satisfy them.
    chosen_source = self._choose_time_spine_source(all_needed_specs)
    transformed_time_spine_node = TransformTimeDimensionsNode.create(
        parent_node=self._choose_time_spine_read_node(chosen_source),
        requested_time_dimension_specs=all_needed_specs,
    )

    # When the time spine's base grain is not among the queried grains, the selected columns repeat across rows,
    # so the output has to be deduplicated.
    base_grain = ExpandedTimeGranularity.from_time_granularity(chosen_source.base_granularity)
    queried_grains = {queried_spec.time_granularity for queried_spec in queried_time_spine_specs}
    needs_distinct = base_grain not in queried_grains

    return self._build_pre_aggregation_plan(
        source_node=transformed_time_spine_node,
        filter_to_specs=InstanceSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
        time_range_constraint=time_range_constraint,
        where_filter_specs=where_filter_specs,
        distinct=needs_distinct,
    )

def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
    """Return a new list of the given time dimension specs ordered by base granularity.

    Ordering precedence: specs without a date part come before specs with one, then standard grains come before
    custom grains, then specs are ordered by the integer value of their base granularity (ascending). The input
    sequence is not modified.
    """

    def _ordering_key(spec: TimeDimensionSpec) -> Tuple[bool, bool, int]:
        return (
            spec.date_part is not None,
            spec.time_granularity.is_custom_granularity,
            spec.time_granularity.base_granularity.to_int(),
        )

    ordered_specs = list(time_dimension_specs)
    ordered_specs.sort(key=_ordering_key)
    return ordered_specs

def _determine_time_spine_join_spec(
    self, measure_properties: MeasureSpecProperties, required_time_spine_specs: Tuple[TimeDimensionSpec, ...]
) -> TimeDimensionSpec:
    """Determine the spec to join on for a time spine join.

    The join spec always uses the grain given by `measure_properties.agg_time_dimension_grain`. If metric_time is
    among `required_time_spine_specs`, the join spec is metric_time at that grain. Otherwise the first spec in
    `required_time_spine_specs` is used, re-grained to the measure's grain and with its date part removed.
    """
    measure_grain = ExpandedTimeGranularity.from_time_granularity(measure_properties.agg_time_dimension_grain)
    metric_time_join_spec = DataSet.metric_time_dimension_spec(time_granularity=measure_grain)
    if LinkableSpecSet(time_dimension_specs=required_time_spine_specs).contains_metric_time:
        return metric_time_join_spec

    # metric_time wasn't requested, so fall back to the agg_time_dimension taken from the first required spec.
    first_agg_time_dimension_spec = required_time_spine_specs[0]
    return first_agg_time_dimension_spec.with_grain_and_date_part(time_granularity=measure_grain, date_part=None)
6 changes: 3 additions & 3 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
semantic_model_lookup: SemanticModelLookup,
nodes_available_for_joins: Sequence[DataflowPlanNode],
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
time_spine_nodes: Sequence[MetricTimeDimensionTransformNode],
time_spine_metric_time_nodes: Sequence[MetricTimeDimensionTransformNode],
) -> None:
"""Initializer.

Expand All @@ -186,7 +186,7 @@ def __init__(
self._node_data_set_resolver = node_data_set_resolver
self._partition_resolver = PartitionJoinResolver(self._semantic_model_lookup)
self._join_evaluator = SemanticModelJoinEvaluator(self._semantic_model_lookup)
self._time_spine_nodes = time_spine_nodes
self._time_spine_metric_time_nodes = time_spine_metric_time_nodes

def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
self,
Expand All @@ -205,7 +205,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
linkable_specs_in_right_node = data_set_in_right_node.instance_set.spec_set.linkable_specs

# If right node is time spine source node, use cross join.
if right_node in self._time_spine_nodes:
if right_node in self._time_spine_metric_time_nodes:
satisfiable_metric_time_specs = [
spec for spec in linkable_specs_in_right_node if spec in needed_linkable_specs
]
Expand Down
28 changes: 18 additions & 10 deletions metricflow/dataflow/builder/source_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,23 @@ class SourceNodeSet:
# Semantic models are 1:1 mapped to a ReadSqlSourceNode.
source_nodes_for_group_by_item_queries: Tuple[DataflowPlanNode, ...]

# Provides the time spines.
time_spine_nodes: Mapping[TimeGranularity, MetricTimeDimensionTransformNode]
# Provides time spines that can be used to satisfy time spine joins.
time_spine_read_nodes: Mapping[TimeGranularity, ReadSqlSourceNode]

# Provides time spines that can be used to satisfy metric_time without metrics.
time_spine_metric_time_nodes: Mapping[TimeGranularity, MetricTimeDimensionTransformNode]

@property
def all_nodes(self) -> Sequence[DataflowPlanNode]: # noqa: D102
return (
self.source_nodes_for_metric_queries
+ self.source_nodes_for_group_by_item_queries
+ self.time_spine_nodes_tuple
+ self.time_spine_metric_time_nodes_tuple
)

@property
def time_spine_nodes_tuple(self) -> Tuple[MetricTimeDimensionTransformNode, ...]: # noqa: D102
return tuple(self.time_spine_nodes.values())
def time_spine_metric_time_nodes_tuple(self) -> Tuple[MetricTimeDimensionTransformNode, ...]: # noqa: D102
return tuple(self.time_spine_metric_time_nodes.values())


class SourceNodeBuilder:
Expand All @@ -65,11 +68,15 @@ def __init__( # noqa: D107
self.time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(
semantic_manifest_lookup.semantic_manifest
)
self._time_spine_source_nodes = {}
for granularity, time_spine_source in self.time_spine_sources.items():

self._time_spine_read_nodes = {}
self._time_spine_metric_time_nodes = {}
for base_granularity, time_spine_source in self.time_spine_sources.items():
data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source)
self._time_spine_source_nodes[granularity] = MetricTimeDimensionTransformNode.create(
parent_node=ReadSqlSourceNode.create(data_set),
read_node = ReadSqlSourceNode.create(data_set)
self._time_spine_read_nodes[base_granularity] = read_node
self._time_spine_metric_time_nodes[base_granularity] = MetricTimeDimensionTransformNode.create(
parent_node=read_node,
aggregation_time_dimension_reference=TimeDimensionReference(time_spine_source.base_column),
)

Expand Down Expand Up @@ -103,7 +110,8 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So
source_nodes_for_metric_queries.append(metric_time_transform_node)

return SourceNodeSet(
time_spine_nodes=self._time_spine_source_nodes,
time_spine_metric_time_nodes=self._time_spine_metric_time_nodes,
time_spine_read_nodes=self._time_spine_read_nodes,
source_nodes_for_group_by_item_queries=tuple(group_by_item_source_nodes),
source_nodes_for_metric_queries=tuple(source_nodes_for_metric_queries),
)
Expand Down
Loading
Loading