Skip to content

Commit

Permalink
joinp: add --sql-filter option
Browse files Browse the repository at this point in the history
to allow users to post-process a join result with a SQL expression.

Primarily, to select columns and filter results.
  • Loading branch information
jqnatividad committed Sep 6, 2023
1 parent 43ff5b8 commit 02a3cd5
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 6 deletions.
30 changes: 24 additions & 6 deletions src/cmd/joinp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ joinp options:
(e.g. 2022-02-29 -> 2022-02-28) instead of erroring.
OUTPUT FORMAT OPTIONS:
--sql-filter <SQL> The SQL expression to apply against the join result.
Ordinarily used to select columns and filter rows from
the join result. Be sure to select from the "join_result"
table when formulating the SQL expression.
(e.g. "select c1, c2 as colname from join_result where c2 > 20")
--datetime-format <fmt> The datetime format to use writing datetimes.
See https://docs.rs/chrono/latest/chrono/format/strftime/index.html
for the list of valid format specifiers.
Expand All @@ -151,7 +156,7 @@ joinp options:
--float-precision <arg> The number of digits of precision to use when writing floats.
(default: 6)
--null-value <arg> The string to use when writing null values.
(default: <empty string>)
(default: <empty string>)
Common options:
-h, --help Display this message
Expand All @@ -172,9 +177,10 @@ use polars::{
datatypes::AnyValue,
frame::hash_join::{JoinType, JoinValidation},
prelude::{
AsOfOptions, AsofStrategy, CsvWriter, LazyCsvReader, LazyFileListReader, LazyFrame,
SerWriter, SortOptions,
AsOfOptions, AsofStrategy, CsvWriter, IntoLazy, LazyCsvReader, LazyFileListReader,
LazyFrame, SerWriter, SortOptions,
},
sql::SQLContext,
};
use serde::Deserialize;
use smartstring;
Expand Down Expand Up @@ -207,6 +213,7 @@ struct Args {
flag_right_by: Option<String>,
flag_strategy: Option<String>,
flag_tolerance: Option<String>,
flag_sql_filter: Option<String>,
flag_datetime_format: Option<String>,
flag_date_format: Option<String>,
flag_time_format: Option<String>,
Expand Down Expand Up @@ -328,6 +335,7 @@ struct JoinStruct {
delim: u8,
streaming: bool,
no_optimizations: bool,
sql_filter: Option<String>,
datetime_format: Option<String>,
date_format: Option<String>,
time_format: Option<String>,
Expand Down Expand Up @@ -384,7 +392,7 @@ impl JoinStruct {
};
log::debug!("Optimization state: {optimization_state:?}");

let mut join_results = if jointype == JoinType::Cross {
let join_results = if jointype == JoinType::Cross {
self.left_lf
.with_optimizations(optimization_state)
.join_builder()
Expand Down Expand Up @@ -414,6 +422,15 @@ impl JoinStruct {
.collect()?
};

let mut results_df = if let Some(sql_filter) = &self.sql_filter {
let mut ctx = SQLContext::new();
ctx.register("join_result", join_results.lazy());
ctx.execute(sql_filter)
.and_then(polars::prelude::LazyFrame::collect)?
} else {
join_results
};

// no need to use buffered writer here, as CsvWriter already does that
let mut out_writer = match self.output {
Some(output_file) => {
Expand All @@ -424,7 +441,7 @@ impl JoinStruct {
};

// shape is the number of rows and columns
let join_shape = join_results.shape();
let join_shape = results_df.shape();

CsvWriter::new(&mut out_writer)
.has_header(true)
Expand All @@ -434,7 +451,7 @@ impl JoinStruct {
.with_time_format(self.time_format)
.with_float_precision(self.float_precision)
.with_null_value(self.null_value)
.finish(&mut join_results)?;
.finish(&mut results_df)?;

Ok(join_shape)
}
Expand Down Expand Up @@ -499,6 +516,7 @@ impl Args {
delim,
streaming: self.flag_streaming,
no_optimizations: self.flag_no_optimizations,
sql_filter: self.flag_sql_filter.clone(),
datetime_format: self.flag_datetime_format.clone(),
date_format: self.flag_date_format.clone(),
time_format: self.flag_time_format.clone(),
Expand Down
41 changes: 41 additions & 0 deletions tests/test_joinp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,44 @@ fn joinp_asof_date_diffcolnames() {
];
assert_eq!(got, expected);
}

#[test]
fn joinp_asof_date_diffcolnames_sqlfilter() {
let wrk = Workdir::new("join_asof_date_diffcolnames_sqlfilter");
wrk.create(
"gdp.csv",
vec![
svec!["gdp_date", "gdp"],
svec!["2016-01-01", "4164"],
svec!["2017-01-01", "4411"],
svec!["2018-01-01", "4566"],
svec!["2019-01-01", "4696"],
],
);
wrk.create(
"population.csv",
vec![
svec!["pop_date", "population"],
svec!["2016-05-12", "82.19"],
svec!["2017-05-12", "82.66"],
svec!["2018-05-12", "83.12"],
svec!["2019-05-12", "83.52"],
],
);

let mut cmd = wrk.command("joinp");
cmd.arg("--asof")
.args(["pop_date", "population.csv", "gdp_date", "gdp.csv"])
.args([
"--sql-filter",
"select pop_date, gdp from join_result where gdp > 4500",
]);

let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd);
let expected = vec![
svec!["pop_date", "gdp"],
svec!["2018-05-12", "4566"],
svec!["2019-05-12", "4696"],
];
assert_eq!(got, expected);
}

0 comments on commit 02a3cd5

Please sign in to comment.