Skip to content

Commit

Permalink
refactor: dataframe join params (#912)
Browse files Browse the repository at this point in the history
* refactor: dataframe join params

* chore: add description for on params

* fix type

* chore: change join param

* chore: update join params in tpch

* oops

* chore: final change

* Add support for join_keys as a positional argument

---------

Co-authored-by: Tim Saucer <[email protected]>
  • Loading branch information
ion-elgreco and timsaucer authored Nov 8, 2024
1 parent cbe28cb commit 4a6c4d1
Show file tree
Hide file tree
Showing 25 changed files with 240 additions and 85 deletions.
10 changes: 5 additions & 5 deletions docs/source/user-guide/common-operations/joins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ will be included in the resulting DataFrame.

.. ipython:: python
left.join(right, join_keys=(["customer_id"], ["id"]), how="inner")
left.join(right, left_on="customer_id", right_on="id", how="inner")
The parameter ``join_keys`` specifies the columns from the left DataFrame and right DataFrame that contains the values
that should match.
Expand All @@ -70,7 +70,7 @@ values for the corresponding columns.

.. ipython:: python
left.join(right, join_keys=(["customer_id"], ["id"]), how="left")
left.join(right, left_on="customer_id", right_on="id", how="left")
Full Join
---------
Expand All @@ -80,7 +80,7 @@ is no match. Unmatched rows will have null values.

.. ipython:: python
left.join(right, join_keys=(["customer_id"], ["id"]), how="full")
left.join(right, left_on="customer_id", right_on="id", how="full")
Left Semi Join
--------------
Expand All @@ -90,7 +90,7 @@ omitting duplicates with multiple matches in the right table.

.. ipython:: python
left.join(right, join_keys=(["customer_id"], ["id"]), how="semi")
left.join(right, left_on="customer_id", right_on="id", how="semi")
Left Anti Join
--------------
Expand All @@ -101,4 +101,4 @@ the right table.

.. ipython:: python
left.join(right, join_keys=(["customer_id"], ["id"]), how="anti")
left.join(right, left_on="customer_id", right_on="id", how="anti")
6 changes: 3 additions & 3 deletions examples/tpch/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest
from importlib import import_module
import pyarrow as pa
from datafusion import col, lit, functions as F
from datafusion import DataFrame, col, lit, functions as F
from util import get_answer_file


Expand Down Expand Up @@ -94,7 +94,7 @@ def check_q17(df):
)
def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
module = import_module(query_code)
df = module.df
df: DataFrame = module.df

# Treat q17 as a special case. The answer file does not match the spec.
# Running at scale factor 1, we have manually verified this result does
Expand All @@ -121,5 +121,5 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):

cols = list(read_schema.names)

assert df.join(df_expected, (cols, cols), "anti").count() == 0
assert df.join(df_expected, on=cols, how="anti").count() == 0
assert df.count() == df_expected.count()
12 changes: 8 additions & 4 deletions examples/tpch/q02_minimum_cost_supplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,20 @@
# Now that we have the region, find suppliers in that region. Suppliers are tied to their nation
# and nations are tied to the region.

df_nation = df_nation.join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner")
df_nation = df_nation.join(
df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner"
)
df_supplier = df_supplier.join(
df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner"
df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
)

# Now that we know who the potential suppliers are for the part, we can limit out part
# supplies table down. We can further join down to the specific parts we've identified
# as matching the request

df = df_partsupp.join(df_supplier, (["ps_suppkey"], ["s_suppkey"]), how="inner")
df = df_partsupp.join(
df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner"
)

# Locate the minimum cost across all suppliers. There are multiple ways you could do this,
# but one way is to create a window function across all suppliers, find the minimum, and
Expand All @@ -111,7 +115,7 @@

df = df.filter(col("min_cost") == col("ps_supplycost"))

df = df.join(df_part, (["ps_partkey"], ["p_partkey"]), how="inner")
df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner")

# From the problem statement, these are the values we wish to output

Expand Down
6 changes: 3 additions & 3 deletions examples/tpch/q03_shipping_priority.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@

# Join all 3 dataframes

df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner").join(
df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner"
)
df = df_customer.join(
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")

# Compute the revenue

Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q04_order_priority_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@
)

# Perform the join to find only orders for which there are lineitems outside of expected range
df = df_orders.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner")
df = df_orders.join(
df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner"
)

# Based on priority, find the number of entries
df = df.aggregate(
Expand Down
13 changes: 8 additions & 5 deletions examples/tpch/q05_local_supplier_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,18 @@
# Join all the dataframes

df = (
df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner")
.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner")
df_customer.join(
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
)
.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")
.join(
df_supplier,
(["l_suppkey", "c_nationkey"], ["s_suppkey", "s_nationkey"]),
left_on=["l_suppkey", "c_nationkey"],
right_on=["s_suppkey", "s_nationkey"],
how="inner",
)
.join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner")
.join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner")
.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner")
.join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner")
)

# Compute the final result
Expand Down
12 changes: 7 additions & 5 deletions examples/tpch/q07_volume_shipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,22 @@

# Limit suppliers to either nation
df_supplier = df_supplier.join(
df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner"
df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
).select(col("s_suppkey"), col("n_name").alias("supp_nation"))

# Limit customers to either nation
df_customer = df_customer.join(
df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner"
df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner"
).select(col("c_custkey"), col("n_name").alias("cust_nation"))

# Join up all the data frames from line items, and make sure the supplier and customer are in
# different nations.
df = (
df_lineitem.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner")
.join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner")
.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner")
df_lineitem.join(
df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner"
)
.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner")
.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner")
.filter(col("cust_nation") != col("supp_nation"))
)

Expand Down
14 changes: 7 additions & 7 deletions examples/tpch/q08_market_share.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,27 @@

# After this join we have all of the possible sales nations
df_regional_customers = df_regional_customers.join(
df_nation, (["r_regionkey"], ["n_regionkey"]), how="inner"
df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner"
)

# Now find the possible customers
df_regional_customers = df_regional_customers.join(
df_customer, (["n_nationkey"], ["c_nationkey"]), how="inner"
df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner"
)

# Next find orders for these customers
df_regional_customers = df_regional_customers.join(
df_orders, (["c_custkey"], ["o_custkey"]), how="inner"
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
)

# Find all line items from these orders
df_regional_customers = df_regional_customers.join(
df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner"
df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner"
)

# Limit to the part of interest
df_regional_customers = df_regional_customers.join(
df_part, (["l_partkey"], ["p_partkey"]), how="inner"
df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner"
)

# Compute the volume for each line item
Expand All @@ -126,7 +126,7 @@

# Determine the suppliers by the limited nation key we have in our single row df above
df_national_suppliers = df_national_suppliers.join(
df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner"
df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner"
)

# When we join to the customer dataframe, we don't want to confuse other columns, so only
Expand All @@ -141,7 +141,7 @@
# column only from suppliers in the nation we are evaluating.

df = df_regional_customers.join(
df_national_suppliers, (["l_suppkey"], ["s_suppkey"]), how="left"
df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left"
)

# Use a case statement to compute the volume sold by suppliers in the nation of interest
Expand Down
13 changes: 8 additions & 5 deletions examples/tpch/q09_product_type_profit_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,16 @@
df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0))

# We have a series of joins that get us to limit down to the line items we need
df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), how="inner")
df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner")
df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner")
df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner")
df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner")
df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner")
df = df.join(
df_partsupp, (["l_suppkey", "l_partkey"], ["ps_suppkey", "ps_partkey"]), how="inner"
df_partsupp,
left_on=["l_suppkey", "l_partkey"],
right_on=["ps_suppkey", "ps_partkey"],
how="inner",
)
df = df.join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner")
df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner")

# Compute the intermediate values and limit down to the expressions we need
df = df.select(
Expand Down
6 changes: 3 additions & 3 deletions examples/tpch/q10_returned_item_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
col("o_orderdate") < date_start_of_quarter + interval_one_quarter
)

df = df.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner")
df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")

# Compute the revenue
df = df.aggregate(
Expand All @@ -83,8 +83,8 @@
)

# Now join in the customer data
df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner")
df = df.join(df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner")
df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner")
df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner")

# These are the columns the problem statement requires
df = df.select(
Expand Down
6 changes: 4 additions & 2 deletions examples/tpch/q11_important_stock_identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@

# Find part supplies of within this target nation

df = df_nation.join(df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner")
df = df_nation.join(
df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner"
)

df = df.join(df_partsupp, (["s_suppkey"], ["ps_suppkey"]), how="inner")
df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner")


# Compute the value of individual parts
Expand Down
2 changes: 1 addition & 1 deletion examples/tpch/q12_ship_mode_order_priority.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@


# We need order priority, so join order df to line item
df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner")
df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner")

# Restrict to line items we care about based on the problem statement.
df = df.filter(col("l_commitdate") < col("l_receiptdate"))
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q13_customer_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
)

# Since we may have customers with no orders we must do a left join
df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="left")
df = df_customer.join(
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left"
)

# Find the number of orders for each customer
df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")])
Expand Down
4 changes: 3 additions & 1 deletion examples/tpch/q14_promotion_effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
)

# Left join so we can sum up the promo parts different from other parts
df = df_lineitem.join(df_part, (["l_partkey"], ["p_partkey"]), "left")
df = df_lineitem.join(
df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left"
)

# Make a factor of 1.0 if it is a promotion, 0.0 otherwise
df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0)))
Expand Down
2 changes: 1 addition & 1 deletion examples/tpch/q15_top_supplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@

# Now that we know the supplier(s) with maximum revenue, get the rest of their information
# from the supplier table
df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), "inner")
df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner")

# Return only the columns requested
df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue")
Expand Down
6 changes: 4 additions & 2 deletions examples/tpch/q16_part_supplier_relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

# Remove unwanted suppliers
df_partsupp = df_partsupp.join(
df_unwanted_suppliers, (["ps_suppkey"], ["s_suppkey"]), "anti"
df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti"
)

# Select the parts we are interested in
Expand All @@ -73,7 +73,9 @@
p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST])
df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null())

df = df_part.join(df_partsupp, (["p_partkey"], ["ps_partkey"]), "inner")
df = df_part.join(
df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner"
)

df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct()

Expand Down
2 changes: 1 addition & 1 deletion examples/tpch/q17_small_quantity_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
)

# Combine data
df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), "inner")
df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner")

# Find the average quantity
window_frame = WindowFrame("rows", None, None)
Expand Down
4 changes: 2 additions & 2 deletions examples/tpch/q18_large_volume_customer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@

# We've identified the orders of interest, now join the additional data
# we are required to report on
df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), "inner")
df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), "inner")
df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner")
df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner")

df = df.select(
"c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity"
Expand Down
2 changes: 1 addition & 1 deletion examples/tpch/q19_discounted_revenue.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
(col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG"))
)

df = df.join(df_part, (["l_partkey"], ["p_partkey"]), "inner")
df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner")


# Create the user defined function (UDF) definition that does the work
Expand Down
Loading

0 comments on commit 4a6c4d1

Please sign in to comment.