diff --git a/src/Analyzer/Passes/CaseWhenSimplifyPass.cpp b/src/Analyzer/Passes/CaseWhenSimplifyPass.cpp index ef17620ea630..5e53062c9acf 100644 --- a/src/Analyzer/Passes/CaseWhenSimplifyPass.cpp +++ b/src/Analyzer/Passes/CaseWhenSimplifyPass.cpp @@ -79,6 +79,11 @@ class ICaseWhenReplaceAlogrithm { chassert(case_node.getFunctionName() == "caseWithExpression"); auto case_args = case_node.getArguments().getNodes(); + if (case_args.size() < 3) + { + failed = true; + return; + } has_else = (case_args.size() - 1) % 2 == 1; case_column = case_args.at(0); for (size_t i = 1; i < case_args.size() - (has_else ? 1 : 0); i += 2) @@ -405,17 +410,17 @@ class CaseWhenSimplifyPassVisitor : public InDepthQueryTreeVisitorWithContext is_null_funcs = {"isNull", "isNotNull"}; + // null property is hard to handle, so abandon the optimization + static const std::unordered_set disabled_parent_funcs = {"isNull", "isNotNull"}; static const std::unordered_set supported_funcs = {"in", "notIn", "equals", "notEquals"}; auto * func_node = node->as(); if (!func_node) return; - if (is_null_funcs.contains(func_node->getFunctionName())) + if (disabled_parent_funcs.contains(func_node->getFunctionName())) { - parentHasIsNull = true; - return; + in_disabled_function = true; } if (!checkFunctionWithArguments(node, supported_funcs, 2)) @@ -448,12 +453,12 @@ class CaseWhenSimplifyPassVisitor : public InDepthQueryTreeVisitorWithContextgetResultType()->equals(*node->getResultType())) node = new_node; } private: - bool parentHasIsNull = false; + bool in_disabled_function = false; }; }