Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: union ops #3872

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,7 @@ class LogicalPlanBuilder:
suffix: str | None = None,
) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def union(self, other: LogicalPlanBuilder, quantifier: str) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def except_(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...
Expand Down
140 changes: 135 additions & 5 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2344,9 +2344,9 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any)
DataFrame: Transformed DataFrame.
"""
result = func(self, *args, **kwargs)
assert isinstance(result, DataFrame), (
f"Func returned an instance of type [{type(result)}], " "should have been DataFrame."
)
assert isinstance(
result, DataFrame
), f"Func returned an instance of type [{type(result)}], should have been DataFrame."
return result

def _agg(
Expand Down Expand Up @@ -2638,7 +2638,11 @@ def groupby(self, *group_by: ManyColumnsInputType) -> "GroupedDataFrame":
>>> import daft
>>> from daft import col
>>> df = daft.from_pydict(
... {"pet": ["cat", "dog", "dog", "cat"], "age": [1, 2, 3, 4], "name": ["Alex", "Jordan", "Sam", "Riley"]}
... {
... "pet": ["cat", "dog", "dog", "cat"],
... "age": [1, 2, 3, 4],
... "name": ["Alex", "Jordan", "Sam", "Riley"],
... }
... )
>>> grouped_df = df.groupby("pet").agg(
... col("age").min().alias("min_age"),
Expand Down Expand Up @@ -2727,6 +2731,128 @@ def pivot(
builder = self._builder.pivot(group_by_expr, pivot_col_expr, value_col_expr, agg_expr, names)
return DataFrame(builder)

@DataframePublicAPI
def union(self, other: "DataFrame") -> "DataFrame":
    """Returns the distinct union of two DataFrames.

    Rows appearing in both inputs (or repeated within one input) occur
    exactly once in the result, as the shared row ``(3, 6)`` below shows.

    Args:
        other (DataFrame): DataFrame to union with this one.

    Returns:
        DataFrame: DataFrame containing the distinct union of both inputs.

    Example:
        >>> import daft
        >>> df1 = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
        >>> df2 = daft.from_pydict({"x": [3, 4, 5], "y": [6, 7, 8]})
        >>> df1.union(df2).sort("x").show()
        ╭───────┬───────╮
        │ x     ┆ y     │
        │ ---   ┆ ---   │
        │ Int64 ┆ Int64 │
        ╞═══════╪═══════╡
        │ 1     ┆ 4     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 2     ┆ 5     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 3     ┆ 6     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 4     ┆ 7     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 5     ┆ 8     │
        ╰───────┴───────╯
        <BLANKLINE>
        (Showing first 5 of 5 rows)
    """
    # A quantifier of None selects the plain DISTINCT union in the builder.
    return DataFrame(self._builder.union(other._builder, None))

@DataframePublicAPI
def union_all(self, other: "DataFrame") -> "DataFrame":
    """Returns the union of two DataFrames, including duplicates.

    Unlike :meth:`union`, no deduplication is performed: every row of both
    inputs is kept, so identical rows appear multiple times below.

    Args:
        other (DataFrame): DataFrame to union with this one.

    Returns:
        DataFrame: DataFrame containing all rows of both inputs, duplicates included.

    Example:
        >>> import daft
        >>> df1 = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
        >>> df2 = daft.from_pydict({"x": [3, 2, 1], "y": [6, 5, 4]})
        >>> df1.union_all(df2).sort("x").show()
        ╭───────┬───────╮
        │ x     ┆ y     │
        │ ---   ┆ ---   │
        │ Int64 ┆ Int64 │
        ╞═══════╪═══════╡
        │ 1     ┆ 4     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 1     ┆ 4     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 2     ┆ 5     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 2     ┆ 5     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 3     ┆ 6     │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
        │ 3     ┆ 6     │
        ╰───────┴───────╯
        <BLANKLINE>
        (Showing first 6 of 6 rows)
    """
    # "all" maps to the UNION ALL quantifier on the Rust side.
    return DataFrame(self._builder.union(other._builder, "all"))

@DataframePublicAPI
def union_by_name(self, other: "DataFrame") -> "DataFrame":
    """Returns the distinct union of two DataFrames, matching columns by name.

    Columns are aligned by name rather than by position; a column missing
    from one side is filled with nulls for that side's rows (see the
    ``None`` values in the example). Duplicate rows are removed.

    Args:
        other (DataFrame): DataFrame to union with this one.

    Returns:
        DataFrame: DataFrame with the distinct, name-aligned union of both inputs.

    Example:
        >>> import daft
        >>> df1 = daft.from_pydict({"x": [1, 2], "y": [4, 5], "w": [9, 10]})
        >>> df2 = daft.from_pydict({"y": [6, 7], "z": ["a", "b"]})
        >>> df1.union_by_name(df2).sort("y").show()
        ╭───────┬───────┬───────┬──────╮
        │ x     ┆ y     ┆ w     ┆ z    │
        │ ---   ┆ ---   ┆ ---   ┆ ---  │
        │ Int64 ┆ Int64 ┆ Int64 ┆ Utf8 │
        ╞═══════╪═══════╪═══════╪══════╡
        │ 1     ┆ 4     ┆ 9     ┆ None │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ 2     ┆ 5     ┆ 10    ┆ None │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ None  ┆ 6     ┆ None  ┆ a    │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ None  ┆ 7     ┆ None  ┆ b    │
        ╰───────┴───────┴───────┴──────╯
        <BLANKLINE>
        (Showing first 4 of 4 rows)
    """
    # "by_name" maps to the distinct-by-name quantifier on the Rust side.
    return DataFrame(self._builder.union(other._builder, "by_name"))

@DataframePublicAPI
def union_all_by_name(self, other: "DataFrame") -> "DataFrame":
    """Returns the union of two DataFrames, including duplicates, with columns matched by name.

    Like :meth:`union_by_name` the columns are aligned by name and missing
    columns are null-filled, but like :meth:`union_all` duplicate rows are
    kept rather than deduplicated.

    Args:
        other (DataFrame): DataFrame to union with this one.

    Returns:
        DataFrame: DataFrame with the name-aligned union of both inputs, duplicates included.

    Example:
        >>> import daft
        >>> df1 = daft.from_pydict({"x": [1, 2], "y": [4, 5], "w": [9, 10]})
        >>> df2 = daft.from_pydict({"y": [6, 6, 7, 7], "z": ["a", "a", "b", "b"]})
        >>> df1.union_all_by_name(df2).sort("y").show()
        ╭───────┬───────┬───────┬──────╮
        │ x     ┆ y     ┆ w     ┆ z    │
        │ ---   ┆ ---   ┆ ---   ┆ ---  │
        │ Int64 ┆ Int64 ┆ Int64 ┆ Utf8 │
        ╞═══════╪═══════╪═══════╪══════╡
        │ 1     ┆ 4     ┆ 9     ┆ None │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ 2     ┆ 5     ┆ 10    ┆ None │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ None  ┆ 6     ┆ None  ┆ a    │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ None  ┆ 6     ┆ None  ┆ a    │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ None  ┆ 7     ┆ None  ┆ b    │
        ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
        │ None  ┆ 7     ┆ None  ┆ b    │
        ╰───────┴───────┴───────┴──────╯
        <BLANKLINE>
        (Showing first 6 of 6 rows)
    """
    # "all_by_name" maps to the ALL-by-name quantifier on the Rust side.
    return DataFrame(self._builder.union(other._builder, "all_by_name"))

@DataframePublicAPI
def intersect(self, other: "DataFrame") -> "DataFrame":
"""Returns the intersection of two DataFrames.
Expand Down Expand Up @@ -3422,7 +3548,11 @@ def agg(self, *to_agg: Union[Expression, Iterable[Expression]]) -> "DataFrame":
>>> import daft
>>> from daft import col
>>> df = daft.from_pydict(
... {"pet": ["cat", "dog", "dog", "cat"], "age": [1, 2, 3, 4], "name": ["Alex", "Jordan", "Sam", "Riley"]}
... {
... "pet": ["cat", "dog", "dog", "cat"],
... "age": [1, 2, 3, 4],
... "name": ["Alex", "Jordan", "Sam", "Riley"],
... }
... )
>>> grouped_df = df.groupby("pet").agg(
... col("age").min().alias("min_age"),
Expand Down
8 changes: 7 additions & 1 deletion daft/logical/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Literal

from daft.context import get_context
from daft.daft import (
Expand Down Expand Up @@ -292,6 +292,12 @@ def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: # type: igno
builder = self._builder.concat(other._builder)
return LogicalPlanBuilder(builder)

def union(
    self, other: LogicalPlanBuilder, quantifier: Literal["all", "by_name", "all_by_name"] | None
) -> LogicalPlanBuilder:
    """Combine this plan with ``other`` via a UNION set operation.

    ``quantifier`` selects the variant handled by the Rust builder:
    ``None`` for a distinct union, ``"all"`` to keep duplicates, and
    ``"by_name"`` / ``"all_by_name"`` for the name-aligned variants.
    """
    return LogicalPlanBuilder(self._builder.union(other._builder, quantifier))

def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
    """Combine this plan with ``other`` via a distinct INTERSECT set operation.

    Always distinct: the underlying Rust builder is invoked with
    ``is_all=False``.
    """
    builder = self._builder.intersect(other._builder, False)
    return LogicalPlanBuilder(builder)
Expand Down
15 changes: 13 additions & 2 deletions src/daft-connect/src/spark_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
use arrow2::io::ipc::read::{read_stream_metadata, StreamReader, StreamState};
use daft_core::{join::JoinSide, series::Series};
use daft_dsl::{unresolved_col, Column, Expr, ExprRef, Operator, PlanRef, UnresolvedColumn};
use daft_logical_plan::{JoinOptions, JoinType, LogicalPlanBuilder, PyLogicalPlanBuilder};
use daft_logical_plan::{
ops::SetQuantifier, JoinOptions, JoinType, LogicalPlanBuilder, PyLogicalPlanBuilder,
};
use daft_micropartition::{self, python::PyMicroPartition, MicroPartition};
use daft_recordbatch::RecordBatch;
use daft_scan::builder::{delta_scan, CsvScanBuilder, JsonScanBuilder, ParquetScanBuilder};
Expand Down Expand Up @@ -701,7 +703,16 @@
match set_op_type {
SetOpType::Except => left.except(&right, is_all),
SetOpType::Intersect => left.intersect(&right, is_all),
SetOpType::Union => left.union(&right, is_all),
SetOpType::Union => {
let set_quantifier = match (is_all, set_op.by_name) {
(true, Some(true)) => SetQuantifier::AllByName,
(true, Some(false)) | (true, None) => SetQuantifier::All,
(false, Some(true)) => SetQuantifier::DistinctByName,
(false, Some(false)) | (false, None) => SetQuantifier::Distinct,

Check warning on line 711 in src/daft-connect/src/spark_analyzer.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/spark_analyzer.rs#L710-L711

Added lines #L710 - L711 were not covered by tests
};

left.union(&right, set_quantifier)
}
SetOpType::Unspecified => {
invalid_argument_err!("SetOpType must be specified; got Unspecified")
}
Expand Down
22 changes: 19 additions & 3 deletions src/daft-logical-plan/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

use crate::{
logical_plan::{LogicalPlan, SubqueryAlias},
ops::{self, join::JoinOptions},
ops::{self, join::JoinOptions, SetQuantifier},
optimization::OptimizerBuilder,
partitioning::{
HashRepartitionConfig, IntoPartitionsConfig, RandomShuffleConfig, RepartitionSpec,
Expand Down Expand Up @@ -608,9 +608,9 @@
Ok(self.with_new_plan(logical_plan))
}

pub fn union(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
pub fn union(&self, other: &Self, set_quantifier: SetQuantifier) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Union::try_new(self.plan.clone(), other.plan.clone(), is_all)?
ops::Union::try_new(self.plan.clone(), other.plan.clone(), set_quantifier)?
.to_logical_plan()?;
Ok(self.with_new_plan(logical_plan))
}
Expand Down Expand Up @@ -1080,6 +1080,22 @@
Ok(self.builder.concat(&other.builder)?.into())
}

#[pyo3(signature = (other, quantifier=None))]
pub fn union(&self, other: &Self, quantifier: Option<String>) -> DaftResult<Self> {
let quantifier = match quantifier.map(|s| s.to_lowercase()).as_deref() {
Some("all") => SetQuantifier::All,
Some("all_by_name") => SetQuantifier::AllByName,
Some("by_name") => SetQuantifier::DistinctByName,
None => SetQuantifier::Distinct,
_ => {
return Err(DaftError::InternalError(
"Invalid set quantifier".to_string(),
))

Check warning on line 1093 in src/daft-logical-plan/src/builder/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-logical-plan/src/builder/mod.rs#L1091-L1093

Added lines #L1091 - L1093 were not covered by tests
}
};
Ok(self.builder.union(&other.builder, quantifier)?.into())
}

pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
Ok(self.builder.intersect(&other.builder, is_all)?.into())
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()),
Self::Intersect(inner) => Self::Intersect(Intersect::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()),
Self::Union(inner) => Self::Union(Union::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()),
Self::Union(inner) => Self::Union(Union::try_new(input1.clone(), input2.clone(), inner.quantifier).unwrap()),

Check warning on line 395 in src/daft-logical-plan/src/logical_plan.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-logical-plan/src/logical_plan.rs#L395

Added line #L395 was not covered by tests
Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. }) => Self::Join(Join::try_new(
input1.clone(),
input2.clone(),
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use pivot::Pivot;
pub use project::Project;
pub use repartition::Repartition;
pub use sample::Sample;
pub use set_operations::{Except, Intersect, Union};
pub use set_operations::{Except, Intersect, SetQuantifier, Union};
pub use sink::Sink;
pub use sort::Sort;
pub use source::Source;
Expand Down
Loading
Loading