From 2b15ad16f3516f29d6540c974170591a6c085478 Mon Sep 17 00:00:00 2001 From: zjregee Date: Tue, 31 Dec 2024 22:22:10 +0800 Subject: [PATCH] consolidate dataframe_subquery.rs into dataframe.rs (#13950) --- datafusion-examples/README.md | 2 +- datafusion-examples/examples/dataframe.rs | 91 ++++++++++++++ .../examples/dataframe_subquery.rs | 118 ------------------ 3 files changed, 92 insertions(+), 119 deletions(-) delete mode 100644 datafusion-examples/examples/dataframe_subquery.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 3ec008a6026d..23cf8830e36d 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -57,7 +57,7 @@ cargo run --example dataframe - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`custom_file_format.rs`](examples/custom_file_format.rs): Write data to a custom file format - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 -- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame API against parquet files, csv files, and in-memory data. Also demonstrates the various methods to write out a DataFrame to a table, parquet file, csv file, and json file. +- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries. Also demonstrates the various methods to write out a DataFrame to a table, parquet file, csv file, and json file. - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify, analyze and coerce `Expr`s - [`file_stream_provider.rs`](examples/file_stream_provider.rs): Run a query on `FileStreamProvider` which implements `StreamProvider` for reading and writing to arbitrary stream sources / sinks. diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 90d7d778ea5c..91d62135b913 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -19,10 +19,13 @@ use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::Result; +use datafusion::functions_aggregate::average::avg; +use datafusion::functions_aggregate::min_max::max; use datafusion::prelude::*; use datafusion_common::config::CsvOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; use std::fs::File; use std::io::Write; use std::sync::Arc; @@ -44,7 +47,14 @@ use tempfile::tempdir; /// /// * [write_out]: write out a DataFrame to a table, parquet file, csv file, or json file /// +/// # Executing subqueries +/// +/// * [where_scalar_subquery]: execute a scalar subquery +/// * [where_in_subquery]: execute a subquery with an IN clause +/// * [where_exist_subquery]: execute a subquery with an EXISTS clause +/// /// # Querying data +/// /// * [query_to_date]: execute queries against parquet files #[tokio::main] async fn main() -> Result<()> { @@ -55,6 +65,11 @@ async fn main() -> Result<()> { read_memory(&ctx).await?; write_out(&ctx).await?; query_to_date().await?; + register_aggregate_test_data("t1", &ctx).await?; + register_aggregate_test_data("t2", &ctx).await?; + where_scalar_subquery(&ctx).await?; + where_in_subquery(&ctx).await?; + where_exist_subquery(&ctx).await?; Ok(()) } @@ -250,3 +265,79 @@ async fn query_to_date() -> Result<()> { Ok(()) } + +/// Use the DataFrame API to execute the following subquery: +/// select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; +async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { + ctx.table("t1") + .await? + .filter( + scalar_subquery(Arc::new( + ctx.table("t2") + .await? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? + .aggregate(vec![], vec![avg(col("t2.c2"))])? + .select(vec![avg(col("t2.c2"))])? + .into_unoptimized_plan(), + )) + .gt(lit(0u8)), + )? + .select(vec![col("t1.c1"), col("t1.c2")])? + .limit(0, Some(3))? + .show() + .await?; + Ok(()) +} + +/// Use the DataFrame API to execute the following subquery: +/// select t1.c1, t1.c2 from t1 where t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; +async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { + ctx.table("t1") + .await? + .filter(in_subquery( + col("t1.c2"), + Arc::new( + ctx.table("t2") + .await? + .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? + .aggregate(vec![], vec![max(col("t2.c2"))])? + .select(vec![max(col("t2.c2"))])? + .into_unoptimized_plan(), + ), + ))? + .select(vec![col("t1.c1"), col("t1.c2")])? + .limit(0, Some(3))? + .show() + .await?; + Ok(()) +} + +/// Use the DataFrame API to execute the following subquery: +/// select t1.c1, t1.c2 from t1 where exists (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; +async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { + ctx.table("t1") + .await? + .filter(exists(Arc::new( + ctx.table("t2") + .await? + .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? + .select(vec![col("t2.c2")])? + .into_unoptimized_plan(), + )))? + .select(vec![col("t1.c1"), col("t1.c2")])? + .limit(0, Some(3))? + .show() + .await?; + Ok(()) +} + +async fn register_aggregate_test_data(name: &str, ctx: &SessionContext) -> Result<()> { + let testdata = datafusion::test_util::arrow_test_data(); + ctx.register_csv( + name, + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::default(), + ) + .await?; + Ok(()) +} diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs deleted file mode 100644 index 3e3d0c1b5a84..000000000000 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ /dev/null @@ -1,118 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow_schema::DataType; -use std::sync::Arc; - -use datafusion::error::Result; -use datafusion::functions_aggregate::average::avg; -use datafusion::functions_aggregate::min_max::max; -use datafusion::prelude::*; -use datafusion::test_util::arrow_test_data; -use datafusion_common::ScalarValue; - -/// This example demonstrates how to use the DataFrame API to create a subquery. -#[tokio::main] -async fn main() -> Result<()> { - let ctx = SessionContext::new(); - register_aggregate_test_data("t1", &ctx).await?; - register_aggregate_test_data("t2", &ctx).await?; - - where_scalar_subquery(&ctx).await?; - - where_in_subquery(&ctx).await?; - - where_exist_subquery(&ctx).await?; - - Ok(()) -} - -//select c1,c2 from t1 where (select avg(t2.c2) from t2 where t1.c1 = t2.c1)>0 limit 3; -async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { - ctx.table("t1") - .await? - .filter( - scalar_subquery(Arc::new( - ctx.table("t2") - .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .aggregate(vec![], vec![avg(col("t2.c2"))])? - .select(vec![avg(col("t2.c2"))])? - .into_unoptimized_plan(), - )) - .gt(lit(0u8)), - )? - .select(vec![col("t1.c1"), col("t1.c2")])? - .limit(0, Some(3))? - .show() - .await?; - Ok(()) -} - -//SELECT t1.c1, t1.c2 FROM t1 WHERE t1.c2 in (select max(t2.c2) from t2 where t2.c1 > 0 ) limit 3; -async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { - ctx.table("t1") - .await? - .filter(in_subquery( - col("t1.c2"), - Arc::new( - ctx.table("t2") - .await? - .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? - .aggregate(vec![], vec![max(col("t2.c2"))])? - .select(vec![max(col("t2.c2"))])? - .into_unoptimized_plan(), - ), - ))? - .select(vec![col("t1.c1"), col("t1.c2")])? - .limit(0, Some(3))? - .show() - .await?; - Ok(()) -} - -//SELECT t1.c1, t1.c2 FROM t1 WHERE EXISTS (select t2.c2 from t2 where t1.c1 = t2.c1) limit 3; -async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { - ctx.table("t1") - .await? - .filter(exists(Arc::new( - ctx.table("t2") - .await? - .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? - .select(vec![col("t2.c2")])? - .into_unoptimized_plan(), - )))? - .select(vec![col("t1.c1"), col("t1.c2")])? - .limit(0, Some(3))? - .show() - .await?; - Ok(()) -} - -pub async fn register_aggregate_test_data( - name: &str, - ctx: &SessionContext, -) -> Result<()> { - let testdata = arrow_test_data(); - ctx.register_csv( - name, - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::default(), - ) - .await?; - Ok(()) -}