Skip to content

Commit

Permalink
add not in
Browse files Browse the repository at this point in the history
  • Loading branch information
liuneng1994 committed Mar 20, 2024
1 parent 45bbef6 commit d2d5f3b
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 19 deletions.
38 changes: 25 additions & 13 deletions src/Analyzer/Passes/ConvertInToEqualPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,28 @@
#include <Analyzer/FunctionNode.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/Passes/ConvertInToEqualPass.h>
#include <Functions/FunctionsComparison.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/equals.h>
#include <Functions/notEquals.h>

namespace DB
{

using FunctionEquals = FunctionComparison<EqualsOp, NameEquals>;

class ConvertInToEqualPassVisitor : public InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<ConvertInToEqualPassVisitor>;
using Base::Base;

FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver()
{
return std::make_unique<FunctionToOverloadResolverAdaptor>(std::make_shared<FunctionEquals>(getContext()->getSettings().decimal_check_overflow));
}

void enterImpl(QueryTreeNodePtr & node)
{
static const std::unordered_map<String, String> MAPPING = {
{"in", "equals"},
{"notIn", "notEquals"}
};
auto * func_node = node->as<FunctionNode>();
if (!func_node || func_node->getFunctionName() != "in" || func_node->getArguments().getNodes().size() != 2)
if (!func_node
|| !MAPPING.contains(func_node->getFunctionName())
|| func_node->getArguments().getNodes().size() != 2)
return ;
auto args = func_node->getArguments().getNodes();
auto * column_node = args[0]->as<ColumnNode>();
Expand All @@ -38,13 +37,26 @@ class ConvertInToEqualPassVisitor : public InDepthQueryTreeVisitorWithContext<Co
// x IN null not equivalent to x = null
if (constant_node->hasSourceExpression() || constant_node->getValue().isNull())
return ;
auto equal_resolver = createInternalFunctionEqualOverloadResolver();
auto equal = std::make_shared<FunctionNode>("equals");
auto result_func_name = MAPPING.at(func_node->getFunctionName());
auto equal = std::make_shared<FunctionNode>(result_func_name);
QueryTreeNodes arguments{column_node->clone(), constant_node->clone()};
equal->getArguments().getNodes() = std::move(arguments);
equal->resolveAsFunction(equal_resolver);
FunctionOverloadResolverPtr resolver;
bool decimal_check_overflow = getContext()->getSettingsRef().decimal_check_overflow;
if (result_func_name == "equals")
{
resolver = createInternalFunctionEqualOverloadResolver(decimal_check_overflow);
}
else
{
resolver = createInternalFunctionNotEqualOverloadResolver(decimal_check_overflow);
}
equal->resolveAsFunction(resolver);
node = equal;
}
private:
FunctionOverloadResolverPtr equal_resolver;
FunctionOverloadResolverPtr not_equal_resolver;
};

void ConvertInToEqualPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)
Expand Down
8 changes: 7 additions & 1 deletion src/Analyzer/Passes/ConvertInToEqualPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
namespace DB
{
/** Optimize `in` to `equals` and `notIn` to `notEquals` if possible.
*
* 1. Convert IN with a single value to `equals`.
* Example: SELECT * from test where x IN (1);
* Result: SELECT * from test where x = 1;
*
* 2. Convert NOT IN with a single value to `notEquals`.
* Example: SELECT * from test where x NOT IN (1);
* Result: SELECT * from test where x != 1;
*
* If the value is NULL or a tuple, do not convert.
*/
class ConvertInToEqualPass final : public IQueryTreePass
{
Expand Down
2 changes: 2 additions & 0 deletions src/Functions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ extract_into_parent_list(clickhouse_functions_sources dbms_sources
multiMatchAny.cpp
checkHyperscanRegexp.cpp
array/has.cpp
equals.cpp
notEquals.cpp
CastOverloadResolver.cpp
)
extract_into_parent_list(clickhouse_functions_headers dbms_headers
Expand Down
5 changes: 5 additions & 0 deletions src/Functions/equals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ REGISTER_FUNCTION(Equals)
factory.registerFunction<FunctionEquals>();
}

/// Builds an overload resolver around a FunctionEquals instance.
/// `decimal_check_overflow` is handed to the FunctionEquals constructor unchanged.
/// The adaptor is created as a unique_ptr and converted to FunctionOverloadResolverPtr on return.
FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow)
{
    auto equals_function = std::make_shared<FunctionEquals>(decimal_check_overflow);
    return std::make_unique<FunctionToOverloadResolverAdaptor>(std::move(equals_function));
}

template <>
ColumnPtr FunctionComparison<EqualsOp, NameEquals>::executeTupleImpl(
const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) const
Expand Down
11 changes: 11 additions & 0 deletions src/Functions/equals.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once
#include <memory>

namespace DB
{

/// Forward declaration only; the full resolver interface is not needed to declare the factory below.
class IFunctionOverloadResolver;
/// Shared-ownership pointer to an overload resolver.
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;

/// Returns an overload resolver wrapping the `equals` comparison function.
/// `decimal_check_overflow` is passed through to the wrapped function's constructor
/// (see the definition in equals.cpp).
FunctionOverloadResolverPtr createInternalFunctionEqualOverloadResolver(bool decimal_check_overflow);
}
5 changes: 5 additions & 0 deletions src/Functions/notEquals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ REGISTER_FUNCTION(NotEquals)
factory.registerFunction<FunctionNotEquals>();
}

/// Builds an overload resolver around a FunctionNotEquals instance.
/// `decimal_check_overflow` is handed to the FunctionNotEquals constructor unchanged.
/// The adaptor is created as a unique_ptr and converted to FunctionOverloadResolverPtr on return.
FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow)
{
    auto not_equals_function = std::make_shared<FunctionNotEquals>(decimal_check_overflow);
    return std::make_unique<FunctionToOverloadResolverAdaptor>(std::move(not_equals_function));
}

template <>
ColumnPtr FunctionComparison<NotEqualsOp, NameNotEquals>::executeTupleImpl(
const ColumnsWithTypeAndName & x, const ColumnsWithTypeAndName & y, size_t tuple_size, size_t input_rows_count) const
Expand Down
11 changes: 11 additions & 0 deletions src/Functions/notEquals.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once
#include <memory>

namespace DB
{

/// Forward declaration only; the full resolver interface is not needed to declare the factory below.
class IFunctionOverloadResolver;
/// Shared-ownership pointer to an overload resolver.
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;

/// Returns an overload resolver wrapping the `notEquals` comparison function.
/// `decimal_check_overflow` is passed through to the wrapped function's constructor
/// (see the definition in notEquals.cpp).
FunctionOverloadResolverPtr createInternalFunctionNotEqualOverloadResolver(bool decimal_check_overflow);
}
27 changes: 25 additions & 2 deletions tests/queries/0_stateless/03013_optimize_in_to_equal.reference
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ a 1
0
0
0
0
0
-------------------
QUERY id: 0
PROJECTION COLUMNS
Expand Down Expand Up @@ -41,3 +39,28 @@ QUERY id: 0
COLUMN id: 7, column_name: x, result_type: String, source_id: 3
CONSTANT id: 8, constant_value: Tuple_(\'a\', \'b\'), constant_value_type: Tuple(String, String)
SETTINGS allow_experimental_analyzer=1
-------------------
b 2
c 3
-------------------
QUERY id: 0
PROJECTION COLUMNS
x String
y Int32
PROJECTION
LIST id: 1, nodes: 2
COLUMN id: 2, column_name: x, result_type: String, source_id: 3
COLUMN id: 4, column_name: y, result_type: Int32, source_id: 3
JOIN TREE
TABLE id: 3, alias: __table1, table_name: default.test
WHERE
FUNCTION id: 5, function_name: notEquals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 6, nodes: 2
COLUMN id: 7, column_name: x, result_type: String, source_id: 3
CONSTANT id: 8, constant_value: \'a\', constant_value_type: String
SETTINGS allow_experimental_analyzer=1
-------------------
a 1
b 2
c 3
10 changes: 7 additions & 3 deletions tests/queries/0_stateless/03013_optimize_in_to_equal.sql
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
-- Exercises single-value IN / NOT IN predicates with the experimental analyzer enabled,
-- printing both query results and query trees so the rewritten function name is visible.
DROP TABLE IF EXISTS test;
CREATE TABLE test (x String, y Int32) ENGINE = MergeTree() ORDER BY x;

-- NOTE(review): two INSERT statements appear below; this text comes from a diff view, so one of
-- them is presumably the removed line. Confirm which one the committed file keeps.
INSERT INTO test VALUES ('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5);

INSERT INTO test VALUES ('a', 1), ('b', 2), ('c', 3);
-- Single-value IN on a String column.
select * from test where x in ('a') SETTINGS allow_experimental_analyzer = 1;
select '-------------------';
-- IN with a NULL right-hand side (this query does not set allow_experimental_analyzer).
select x in Null from test;
select '-------------------';
-- Query tree for the single-value IN case.
explain query tree select * from test where x in ('a') SETTINGS allow_experimental_analyzer = 1;
select '-------------------';
-- Two-value IN: the reference output shows the right-hand side kept as a Tuple constant.
explain query tree select * from test where x in ('a','b') SETTINGS allow_experimental_analyzer = 1;
select '-------------------';
-- Single-value NOT IN on a String column.
select * from test where x not in ('a') SETTINGS allow_experimental_analyzer = 1;
select '-------------------';
-- Query tree for the single-value NOT IN case: the reference output shows function_name notEquals.
explain query tree select * from test where x not in ('a') SETTINGS allow_experimental_analyzer = 1;
select '-------------------';
-- NOT IN with a NULL element: the reference output lists rows a, b and c.
select * from test where x not in (NULL) SETTINGS allow_experimental_analyzer = 1;

0 comments on commit d2d5f3b

Please sign in to comment.