Skip to content

Commit

Permalink
fix(query): reuse connection to fix dictionary mysql flaky test (#17016)
Browse files Browse the repository at this point in the history
* fix(query): reuse connection to fix dictionary mysql flaky test

* mock mysql source use thread pool
  • Loading branch information
b41sh authored Dec 10, 2024
1 parent 8854bab commit 644f17b
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 55 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ mysql_common = "0.32.4"
quickcheck = "1.0"
sqllogictest = "0.21.0"
sqlparser = "0.50.0"
threadpool = "1.8"

[workspace.lints.rust]
async_fn_in_trait = "allow"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ use crate::sql::executor::physical_plans::AsyncFunctionDesc;
use crate::sql::plans::AsyncFunctionArgument;
use crate::sql::plans::DictGetFunctionArgument;
use crate::sql::plans::DictionarySource;
use crate::sql::plans::SqlSource;
use crate::sql::IndexType;

macro_rules! sqlx_fetch_optional {
Expand Down Expand Up @@ -160,20 +161,20 @@ macro_rules! fetch_all_rows_by_sqlx {

/// Connection handle used to serve dictionary lookups, one variant per
/// supported external source.
pub(crate) enum DictionaryOperator {
/// Redis-backed dictionary; wraps a `ConnectionManager`.
Redis(ConnectionManager),
// NOTE(review): this tuple form (pool plus a SQL string) appears to be the
// pre-change line of the diff — confirm against the actual file.
Mysql((MySqlPool, String)),
/// MySQL-backed dictionary; wraps only the `MySqlPool`.
Mysql(MySqlPool),
}

impl DictionaryOperator {
async fn dict_get(
&self,
value: &Value<AnyType>,
data_type: &DataType,
default_value: &Scalar,
dict_arg: &DictGetFunctionArgument,
) -> Result<Value<AnyType>> {
match self {
DictionaryOperator::Redis(connection) => match value {
Value::Scalar(scalar) => {
self.get_scalar_value_from_redis(scalar, connection, default_value)
self.get_scalar_value_from_redis(scalar, connection, &dict_arg.default_value)
.await
}
Value::Column(column) => {
Expand All @@ -192,24 +193,38 @@ impl DictionaryOperator {
validity,
data_type,
connection,
default_value,
&dict_arg.default_value,
)
.await
}
},
DictionaryOperator::Mysql((pool, sql)) => match value {
Value::Scalar(scalar) => {
let value = self
.get_scalar_value_from_mysql(scalar.as_ref(), data_type, pool, sql)
.await?
.unwrap_or(default_value.clone());
Ok(Value::Scalar(value))
}
Value::Column(column) => {
self.get_column_values_from_mysql(column, data_type, default_value, pool, sql)
DictionaryOperator::Mysql(pool) => {
let sql_source = dict_arg.dict_source.as_mysql().unwrap();
match value {
Value::Scalar(scalar) => {
let value = self
.get_scalar_value_from_mysql(
scalar.as_ref(),
data_type,
pool,
sql_source,
)
.await?
.unwrap_or(dict_arg.default_value.clone());
Ok(Value::Scalar(value))
}
Value::Column(column) => {
self.get_column_values_from_mysql(
column,
data_type,
&dict_arg.default_value,
pool,
sql_source,
)
.await
}
}
},
}
}
}

Expand Down Expand Up @@ -342,34 +357,42 @@ impl DictionaryOperator {
key: ScalarRef<'_>,
value_type: &DataType,
pool: &MySqlPool,
sql: &String,
sql_source: &SqlSource,
) -> Result<Option<Scalar>> {
if key == ScalarRef::Null {
return Ok(None);
}
let new_sql = format!("{} ({}) LIMIT 1", sql, self.format_key(key.clone()));

let sql = format!(
"SELECT {}, {} FROM {} WHERE {} = {} LIMIT 1",
sql_source.key_field,
sql_source.value_field,
sql_source.table,
sql_source.key_field,
self.format_key(key.clone())
);
let key_type = key.infer_data_type().remove_nullable();
match value_type.remove_nullable() {
DataType::Boolean => {
fetch_single_row_by_sqlx!(pool, new_sql, key_type, bool, Scalar::Boolean)
fetch_single_row_by_sqlx!(pool, sql, key_type, bool, Scalar::Boolean)
}
DataType::String => {
fetch_single_row_by_sqlx!(pool, new_sql, key_type, String, Scalar::String)
fetch_single_row_by_sqlx!(pool, sql, key_type, String, Scalar::String)
}
DataType::Number(num_ty) => {
with_integer_mapped_type!(|NUM_TYPE| match num_ty {
NumberDataType::NUM_TYPE => {
fetch_single_row_by_sqlx!(pool, new_sql, key_type, NUM_TYPE, |v| {
fetch_single_row_by_sqlx!(pool, sql, key_type, NUM_TYPE, |v| {
Scalar::Number(NUM_TYPE::upcast_scalar(v))
})
}
NumberDataType::Float32 => {
fetch_single_row_by_sqlx!(pool, new_sql, key_type, f32, |v: f32| {
fetch_single_row_by_sqlx!(pool, sql, key_type, f32, |v: f32| {
Scalar::Number(NumberScalar::Float32(v.into()))
})
}
NumberDataType::Float64 => {
fetch_single_row_by_sqlx!(pool, new_sql, key_type, f64, |v: f64| {
fetch_single_row_by_sqlx!(pool, sql, key_type, f64, |v: f64| {
Scalar::Number(NumberScalar::Float64(v.into()))
})
}
Expand All @@ -387,7 +410,7 @@ impl DictionaryOperator {
value_type: &DataType,
default_value: &Scalar,
pool: &MySqlPool,
sql: &String,
sql_source: &SqlSource,
) -> Result<Value<AnyType>> {
// todo: The current method formats the key as a string, which causes some performance overhead.
// The next step is to use the key's native types directly, such as bool, i32, etc.
Expand All @@ -408,12 +431,20 @@ impl DictionaryOperator {
}
return Ok(Value::Column(builder.build()));
}
let new_sql = format!("{} ({})", sql, self.format_keys(key_set));

let sql = format!(
"SELECT {}, {} FROM {} WHERE {} IN ({})",
sql_source.key_field,
sql_source.value_field,
sql_source.table,
sql_source.key_field,
self.format_keys(key_set)
);
let key_type = column.data_type().remove_nullable();
match value_type.remove_nullable() {
DataType::Boolean => {
let kv_pairs: HashMap<String, bool> =
fetch_all_rows_by_sqlx!(pool, &new_sql, key_type, bool, |k| self.format_key(k));
fetch_all_rows_by_sqlx!(pool, &sql, key_type, bool, |k| self.format_key(k));
for key in all_keys {
match kv_pairs.get(&key) {
Some(v) => builder.push(Scalar::Boolean(*v).as_ref()),
Expand All @@ -423,8 +454,7 @@ impl DictionaryOperator {
}
DataType::String => {
let kv_pairs: HashMap<String, String> =
fetch_all_rows_by_sqlx!(pool, &new_sql, key_type, String, |k| self
.format_key(k));
fetch_all_rows_by_sqlx!(pool, &sql, key_type, String, |k| self.format_key(k));
for key in all_keys {
match kv_pairs.get(&key) {
Some(v) => builder.push(Scalar::String(v.to_string()).as_ref()),
Expand All @@ -436,7 +466,7 @@ impl DictionaryOperator {
with_integer_mapped_type!(|NUM_TYPE| match num_ty {
NumberDataType::NUM_TYPE => {
let kv_pairs: HashMap<String, NUM_TYPE> =
fetch_all_rows_by_sqlx!(pool, &new_sql, key_type, NUM_TYPE, |k| self
fetch_all_rows_by_sqlx!(pool, &sql, key_type, NUM_TYPE, |k| self
.format_key(k));
for key in all_keys {
match kv_pairs.get(&key) {
Expand All @@ -448,7 +478,7 @@ impl DictionaryOperator {
}
NumberDataType::Float32 => {
let kv_pairs: HashMap<String, f32> =
fetch_all_rows_by_sqlx!(pool, &new_sql, key_type, f32, |k| self
fetch_all_rows_by_sqlx!(pool, &sql, key_type, f32, |k| self
.format_key(k));
for key in all_keys {
match kv_pairs.get(&key) {
Expand All @@ -461,7 +491,7 @@ impl DictionaryOperator {
}
NumberDataType::Float64 => {
let kv_pairs: HashMap<String, f64> =
fetch_all_rows_by_sqlx!(pool, &new_sql, key_type, f64, |k| self
fetch_all_rows_by_sqlx!(pool, &sql, key_type, f64, |k| self
.format_key(k));
for key in all_keys {
match kv_pairs.get(&key) {
Expand Down Expand Up @@ -511,11 +541,17 @@ impl TransformAsyncFunction {
pub(crate) fn init_operators(
async_func_descs: &[AsyncFunctionDesc],
) -> Result<BTreeMap<usize, Arc<DictionaryOperator>>> {
let mut operator_map: HashMap<String, Arc<DictionaryOperator>> = HashMap::new();
let mut operators = BTreeMap::new();
for (i, async_func_desc) in async_func_descs.iter().enumerate() {
if let AsyncFunctionArgument::DictGetFunction(dict_arg) = &async_func_desc.func_arg {
match &dict_arg.dict_source {
DictionarySource::Redis(redis_source) => {
let conn_url = format!("{}", redis_source);
if let Some(operator) = operator_map.get(&conn_url) {
operators.insert(i, operator.clone());
continue;
}
let connection_info = ConnectionInfo {
addr: redis::ConnectionAddr::Tcp(
redis_source.host.clone(),
Expand All @@ -532,20 +568,21 @@ impl TransformAsyncFunction {
let conn = databend_common_base::runtime::block_on(
ConnectionManager::new(client),
)?;
operators.insert(i, Arc::new(DictionaryOperator::Redis(conn)));
let operator = Arc::new(DictionaryOperator::Redis(conn));
operator_map.insert(conn_url, operator.clone());
operators.insert(i, operator);
}
DictionarySource::Mysql(sql_source) => {
if let Some(operator) = operator_map.get(&sql_source.connection_url) {
operators.insert(i, operator.clone());
continue;
}
let mysql_pool = databend_common_base::runtime::block_on(
sqlx::MySqlPool::connect(&sql_source.connection_url),
)?;
let sql = format!(
"SELECT {}, {} FROM {} WHERE {} in",
&sql_source.key_field,
&sql_source.value_field,
&sql_source.table,
&sql_source.key_field
);
operators.insert(i, Arc::new(DictionaryOperator::Mysql((mysql_pool, sql))));
let operator = Arc::new(DictionaryOperator::Mysql(mysql_pool));
operator_map.insert(sql_source.connection_url.clone(), operator.clone());
operators.insert(i, operator);
}
}
}
Expand All @@ -566,8 +603,7 @@ impl TransformAsyncFunction {
// only support one key field.
let arg_index = arg_indices[0];
let entry = data_block.get_by_offset(arg_index);
let default_value = dict_arg.default_value.clone();
let value = op.dict_get(&entry.value, data_type, &default_value).await?;
let value = op.dict_get(&entry.value, data_type, dict_arg).await?;
let entry = BlockEntry {
data_type: data_type.clone(),
value,
Expand Down
21 changes: 20 additions & 1 deletion src/query/sql/src/planner/plans/scalar_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

use std::collections::HashMap;
use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hash;
use std::hash::Hasher;
use std::sync::Arc;
Expand Down Expand Up @@ -824,6 +826,23 @@ pub struct RedisSource {
pub db_index: Option<i64>,
}

/// Renders the source as a Redis connection URL of the form
/// `redis://[username][:password]@host:port[/db_index]`.
///
/// The credential section (and its trailing `@`) is emitted only when a
/// username or a password is present. The database index suffix is emitted
/// only when `db_index` is set.
///
/// Note: the password is written in clear text, so avoid sending this output
/// to logs or error messages.
impl Display for RedisSource {
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(f, "redis://")?;
        if let Some(username) = &self.username {
            write!(f, "{}", username)?;
        }
        if let Some(password) = &self.password {
            // The colon belongs to the password so that a password-only
            // source renders as `:password@` rather than as a username.
            write!(f, ":{}", password)?;
        }
        if self.username.is_some() || self.password.is_some() {
            // Separate credentials from the host even when only one of the
            // two credential parts is set.
            write!(f, "@")?;
        }
        write!(f, "{}:{}", self.host, self.port)?;
        if let Some(db_index) = &self.db_index {
            write!(f, "/{}", db_index)?;
        }
        Ok(())
    }
}

#[derive(Clone, Debug, Educe, serde::Serialize, serde::Deserialize)]
#[educe(PartialEq, Eq, Hash)]
pub struct SqlSource {
Expand All @@ -834,7 +853,7 @@ pub struct SqlSource {
pub value_field: String,
}

#[derive(Clone, Debug, Educe, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, Educe, EnumAsInner, serde::Serialize, serde::Deserialize)]
#[educe(PartialEq, Eq, Hash)]
pub enum DictionarySource {
Mysql(SqlSource),
Expand Down
1 change: 1 addition & 0 deletions tests/sqllogictests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ serde_json = { workspace = true }
sqllogictest = { workspace = true }
sqlparser = { workspace = true }
thiserror = { workspace = true }
threadpool = { workspace = true }
tokio = { workspace = true }
url = { workspace = true }
walkdir = { workspace = true }
Expand Down
31 changes: 25 additions & 6 deletions tests/sqllogictests/src/mock_source/mysql_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,32 @@ use msql_srv::MysqlShim;
use msql_srv::QueryResultWriter;
use msql_srv::StatementMetaWriter;
use mysql_common::Value;
use sqlparser::ast::BinaryOperator;
use sqlparser::ast::Expr;
use sqlparser::ast::SelectItem;
use sqlparser::ast::SetExpr;
use sqlparser::ast::Statement;
use sqlparser::ast::TableFactor;
use sqlparser::dialect::MySqlDialect;
use sqlparser::parser::Parser;
use threadpool::ThreadPool;

pub fn run_mysql_source() {
// Bind the listener to the address
let listener = TcpListener::bind("0.0.0.0:3106").unwrap();

// Create a thread pool
let pool = ThreadPool::new(32);
let backend = Backend::create();

loop {
if let Ok((socket, _)) = listener.accept() {
let backend = backend.clone();
databend_common_base::runtime::Thread::spawn(move || {
MysqlIntermediary::run_on_tcp(backend, socket).unwrap();

pool.execute(move || {
if let Err(e) = MysqlIntermediary::run_on_tcp(backend, socket) {
eprintln!("handle MySQL connection error: {}", e);
}
});
}
}
Expand Down Expand Up @@ -186,11 +194,22 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
if let TableFactor::Table { name, .. } = &select.from[0].relation {
table = Some(name.0[0].value.clone());
}
if let Some(Expr::InList { expr, list, .. }) = &select.selection {
if let Expr::Identifier(ident) = *expr.clone() {
key = Some(ident.value.clone());
in_list_keys.extend(list.clone());
match &select.selection {
Some(Expr::InList { expr, list, .. }) => {
if let Expr::Identifier(ident) = *expr.clone() {
key = Some(ident.value.clone());
in_list_keys.extend(list.clone());
}
}
Some(Expr::BinaryOp { left, op, right }) => {
if op == &BinaryOperator::Eq {
if let Expr::Identifier(ident) = *left.clone() {
key = Some(ident.value.clone());
in_list_keys.push(*right.clone());
}
}
}
_ => {}
}
}
}
Expand Down
Loading

0 comments on commit 644f17b

Please sign in to comment.